scikit-learn_knn
2017-05-15 本文已影响38人
Ledestin
本次Demo使用scikit-learn中的knn算法来对其自带的iris鸢尾花数据集进行分类,
1.该数据集有四个特征,三个类
2.此例中scikit-learn有三个功能:
1)datasets 自带数据集:
from sklearn import datasets
2)交叉验证的方法将数据集分为训练集和测试集:
from sklearn.cross_validation import train_test_split
3)knn算法:
from sklearn.neighbors import KNeighborsClassifier
Demo.py
import numpy as np
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 加载iris数据集
iris = datasets.load_iris()
# 读取特征
iris_X = iris.data
# 读取分类标签
iris_y = iris.target
# 将数据分为训练、测试两部分
X_train, X_test, y_train, y_test = train_test_split(iris_X, iris_y, test_size = 0.2)
# 定义分类器
knn = KNeighborsClassifier()
# 进行分类
knn.fit(X_train, y_train)
# 计算预测值
y_predict = knn.predict(X_test)
# 计算准确率, 由于每次数据集划分不同, 可能不一样
print((np.sum(np.fabs(y_predict - y_test))) / float(len(y_test)))
结果:
0.0666666666667