Apriori算法 Python 3.11 实战:从0到1实现超市购物篮分析,支持度/置信度调优

📅 2026/7/5 12:09:29
Apriori算法 Python 3.11 实战:从0到1实现超市购物篮分析,支持度/置信度调优
Apriori算法Python 3.11实战从0到1实现超市购物篮分析支持度/置信度调优超市货架上商品的摆放看似随意实则暗藏玄机。当你拿起一罐啤酒时旁边的薯片是否在向你招手这种商品之间的隐秘联系正是购物篮分析的核心。本文将带你用Python 3.11和mlxtend库从零构建完整的Apriori算法实现揭开商品关联规则的神秘面纱。1. 环境准备与数据理解在开始编码前我们需要搭建合适的工作环境。Python 3.11的性能优化特别适合处理大规模数据集这正是购物篮分析所需要的。# 安装必要的库 pip install mlxtend pandas numpy matplotlib购物篮数据通常呈现为稀疏矩阵形式每一行代表一次交易每一列表示一个商品是否被购买。以下是典型的超市交易数据示例交易ID牛奶面包啤酒尿布鸡蛋111000201110311110410010501100关键指标解释支持度(Support): 项集出现的频率如P(牛奶∩面包)置信度(Confidence): 规则X→Y的强度如P(面包|牛奶)提升度(Lift): X和Y的相关性P(Y|X)/P(Y)2. 数据预处理实战真实数据往往需要清洗才能用于分析。我们使用Python进行数据转换import pandas as pd from mlxtend.preprocessing import TransactionEncoder # 示例交易数据 dataset [[牛奶, 面包], [面包, 啤酒, 尿布], [牛奶, 面包, 啤酒, 尿布], [牛奶, 尿布], [面包, 啤酒]] # 转换为适合mlxtend的格式 te TransactionEncoder() te_ary te.fit(dataset).transform(dataset) df pd.DataFrame(te_ary, columnste.columns_) print(df.head())处理后的数据变为布尔矩阵True表示该商品在该次交易中被购买。这一步对后续分析至关重要因为Apriori算法的输入需要这种格式。3. Apriori算法核心实现mlxtend库提供了高效的Apriori实现我们通过调整参数来观察不同结果from mlxtend.frequent_patterns import apriori # 寻找频繁项集 frequent_itemsets apriori(df, min_support0.4, use_colnamesTrue) print(frequent_itemsets)参数调优实验 我们通过网格搜索寻找最佳支持度阈值import matplotlib.pyplot as plt supports [0.1, 0.2, 0.3, 0.4, 0.5] num_itemsets [] for s in supports: fi apriori(df, min_supports, use_colnamesTrue) num_itemsets.append(len(fi)) plt.plot(supports, num_itemsets, markero) plt.xlabel(最小支持度) plt.ylabel(频繁项集数量) plt.title(支持度阈值对结果的影响) plt.show()这个可视化清晰地展示了支持度阈值与发现的频繁项集数量之间的权衡关系——阈值越高得到的项集越少但更可靠。4. 关联规则生成与解释从频繁项集生成有意义的规则是分析的关键步骤from mlxtend.frequent_patterns import association_rules rules association_rules(frequent_itemsets, metricconfidence, min_threshold0.7) print(rules.sort_values(lift, ascendingFalse))生成的规则包含多个重要指标antecedentsconsequentssupportconfidencelift(啤酒)(面包)0.61.01.25(尿布)(面包)0.60.750.94规则解读啤酒→面包的置信度为100%但提升度仅1.25说明两者正相关但不强烈尿布→面包的提升度小于1表明两者反而有轻微排斥提示高置信度不一定代表强规则需结合提升度判断。提升度1表示正相关1表示独立1表示负相关。5. 高级应用与性能优化当处理真实超市数据时性能成为关键考量。以下是优化策略内存优化技巧# 使用稀疏矩阵处理大型数据集 from scipy.sparse import csr_matrix sparse_df csr_matrix(df.values) frequent_itemsets apriori(sparse_df, min_support0.1, use_colnamesTrue)并行计算加速# 使用joblib并行化 from joblib import Parallel, delayed def parallel_apriori(chunk): return apriori(chunk, min_support0.2) results Parallel(n_jobs4)(delayed(parallel_apriori)(chunk) for chunk in np.array_split(df, 4))FP-Growth对比 对于极大数据集FP-Growth算法效率更高from mlxtend.frequent_patterns import fpgrowth frequent_itemsets fpgrowth(df, min_support0.2, use_colnamesTrue)6. 商业决策支持应用基于分析结果我们可以制定多种商业策略商品陈列优化# 找出高提升度组合 high_lift rules[rules[lift] 2] print(建议相邻摆放的商品组合) print(high_lift[[antecedents,consequents]])促销策略制定# 找出单向强规则 one_way rules[(rules[confidence] 0.8) (rules[lift] 1.5)] print(推荐促销组合) for _, row in one_way.iterrows(): print(f主推商品{list(row[antecedents])[0]}搭售商品{list(row[consequents])[0]})库存管理应用# 预测关联商品需求 related_items {} for itemset in frequent_itemsets[itemsets]: if len(itemset) 1: key tuple(itemset) related_items[key] frequent_itemsets[frequent_itemsets[itemsets]itemset][support].values[0] print(经常一起购买的商品组及出现频率) print(related_items)7. 模型评估与验证为确保模型可靠性我们需要系统评估交叉验证设计from sklearn.model_selection import KFold kf KFold(n_splits5) stabilities [] for train_idx, test_idx in kf.split(df): train df.iloc[train_idx] test df.iloc[test_idx] train_rules association_rules(apriori(train, min_support0.3), metricconfidence, min_threshold0.6) test_rules association_rules(apriori(test, min_support0.3), metricconfidence, min_threshold0.6) # 计算规则重叠率 common set(train_rules[antecedents]).intersection(set(test_rules[antecedents])) stabilities.append(len(common)/len(train_rules)) print(f规则平均稳定性{np.mean(stabilities):.2f})指标对比表评估指标说明理想值规则稳定性交叉验证中规则的一致性0.7业务贴合度被业务专家认可的规则比例0.8预测准确率规则预测新交易的正确率0.75通过系统评估我们可以确定最佳参数组合确保模型既不过拟合又能发现真实模式。8. 可视化分析与报告生成最后我们创建专业可视化帮助决策import networkx as nx # 创建关联网络图 G nx.Graph() for _, row in rules.iterrows(): G.add_edge(list(row[antecedents])[0], list(row[consequents])[0], weightrow[lift]) plt.figure(figsize(10,8)) pos nx.spring_layout(G) nx.draw_networkx_nodes(G, pos, node_size2000, alpha0.6) nx.draw_networkx_edges(G, pos, width[d[weight] for _,_,d in G.edges(dataTrue)], alpha0.5) nx.draw_networkx_labels(G, pos, font_size12) plt.title(商品关联网络边粗细表示提升度) plt.show()热力图展示支持度与置信度import seaborn as sns # 创建规则热力图 rules[antecedent_len] rules[antecedents].apply(lambda x: len(x)) pivot rules.pivot_table(indexantecedents, columnsconsequents, valuesconfidence, aggfuncmean) plt.figure(figsize(12,8)) sns.heatmap(pivot, annotTrue, fmt.2f, cmapYlOrRd) plt.title(规则置信度热力图) plt.show()这些可视化工具让复杂的关联规则变得直观易懂帮助非技术人员理解分析结果。