scikit-learn_knn

2017-05-15  本文已影响38人  Ledestin

本次Demo使用scikit-learn中的knn算法来对其自带的iris鸢尾花数据集进行分类,

1.该数据集有四个特征,三个类

2.此例中scikit-learn有三个功能:

1)datasets 自带数据集:

    from sklearn import datasets

2)交叉验证的方法将数据集分为训练集和测试集:

    from sklearn.cross_validation import train_test_split

3)knn算法:

    from sklearn.neighbors import KNeighborsClassifier

Demo.py

import numpy as np
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# 加载iris数据集
iris = datasets.load_iris()
# 读取特征
iris_X = iris.data
# 读取分类标签
iris_y = iris.target
# 将数据分为训练、测试两部分
X_train, X_test, y_train, y_test = train_test_split(iris_X, iris_y, test_size = 0.2)
# 定义分类器
knn = KNeighborsClassifier()
# 进行分类
knn.fit(X_train, y_train)
# 计算预测值
y_predict = knn.predict(X_test)
# 计算准确率, 由于每次数据集划分不同, 可能不一样
print((np.sum(np.fabs(y_predict - y_test))) / float(len(y_test)))


结果:

0.0666666666667

上一篇下一篇

猜你喜欢

热点阅读