02-KNN案例:鸢尾花分类

📅 2026/7/5 13:52:32
02-KNN案例:鸢尾花分类
1. 需求说明基于鸢尾花数据集使用 KNN 实现 3 类鸢尾花自动分类。2. 数据集简介来源经典公开数据集sklearn 内置。样本共 150 条分 3 类鸢尾花山鸢尾、变色鸢尾、维吉尼亚鸢尾每类各 50 条。特征4 个数值特征花萼长、花萼宽、花瓣长、花瓣宽。sepal length (cm)花萼长度sepal width (cm)花萼宽度petal length (cm)花瓣长度petal width (cm)花瓣宽度标签3 种花卉类别无缺失、无异常值数据干净适合 KNN 等基础分类算法测试。0setosa 山鸢尾1versicolor 变色鸢尾2virginica 维吉尼亚鸢尾3. KNN算法实现3.1 导包from sklearn.datasets import load_iris # 导入数据集 from sklearn.model_selection import train_test_split, GridSearchCV # 划分数据集, 交叉验证网格搜索 from sklearn.neighbors import KNeighborsClassifier # 创建模型对象 from sklearn.metrics import accuracy_score # 模型评估 from sklearn.preprocessing import StandardScaler # 标准化对象 import seaborn as sns import pandas as pd import matplotlib.pyplot as plt3.2 获取数据# 1. 加载数据集 data load_iris() # 2. 查看数据集 # print(data.keys()) # 获取数据集所有的key print(data.data[:5]) # 数据集的前5行 print(data.target[:5]) # 数据集的标签的前5行 print(data.target_names) # 数据集的标签名 print(data.feature_names) # 数据集的属性名 # print(data.DESCR) # 数据集的描述 # print(data.frame) # 数据集的DataFrame格式 # print(data.filename) # 数据集的文件名 # print(data.data_module) # 数据集的模块 # 3. 可视化 data_df pd.DataFrame(data.data, columnsdata.feature_names) data_df[target] data.target # 散点图 sns.lmplot(xsepal length (cm), ysepal width (cm), huetarget, datadata_df, fit_regFalse) plt.title(iris data) plt.tight_layout() # 自动调整子图参数使之填充整个图像 plt.show()3.3 数据预处理缺失值、异常值处理这里不用划分数据集# 划分数据集 # 参数解释 # test_size测试集所占的比例默认为0.25 # random_state随机数种子可以指定一个整数从而保证每次运行时数据集的划分都是固定的 # 返回值x_train, x_test, y_train, y_test x_train, x_test, y_train, y_test train_test_split(data.data, data.target, test_size0.2, random_state22) # 打印划分后的数据集 print(f训练集特征的大小 {x_train.shape}) print(f测试集特征的大小 {x_test.shape}) print(f训练集标签的大小 {y_train.shape}) print(f测试集标签的大小 {y_test.shape})3.4 特征工程3.4.1 特征提取这里不用3.4.2 特征预处理标准化scaler StandardScaler() x_train scaler.fit_transform(x_train) # 方法解释fit_transform()兼具训练和转换的功能先训练再转换适用于第一次进行标准化一般用于训练集 x_test scaler.transform(x_test) # 方法解释transform()只转换适用于已经训练的模型进行转换3.5 模型训练网格搜索交叉验证# 创建模型对象 estimator KNeighborsClassifier() # 网格搜索交叉验证 param_dict {n_neighbors: [i for i in range(1, 11)]} # 超参数字典, 超参可能出现的值 estimator GridSearchCV(estimator, param_gridparam_dict, cv5) # 创建GridSearchCV模型对象 estimator.fit(x_train, y_train) # 交叉验证前的模型训练 # 打印结果 print(f最优评分为{estimator.best_score_}) print(f最优超参组合为{estimator.best_params_}) print(f最优的估计器对象为{estimator.best_estimator_}) print(f具体的交叉验证结果为{estimator.cv_results_})3.6 模型评估# 1获取最优超参的模型对象 estimator estimator.best_estimator_ # 2模型训练 estimator.fit(x_train, y_train) # 3模型预测 y_pre estimator.predict(x_test) # 4模型评估 print(f准确率{accuracy_score(y_test, y_pre)})3.7 模型预测# 样本外预测 x # 要预测的数据集 y_pre estimator.predict(x) print(y_pre)附件1. 完整代码这里准备放一个KNN万能建模.py文件还没整理好~