KNN 算法-实战篇-如何识别手写数字

2021-02-24  本文已影响0人  码农充电站pro

上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字

1,手写数字数据集

手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算法来识别这些数字。

MNIST 是完整的手写数字数据集,其中包含了60000 个训练样本和10000 个测试样本。

sklearn 中也有一个自带的手写数字数据集

我们抽出 5 个样本来看下:

0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0
0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1
0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2
0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3
0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4

使用该数据集,需要先加载:

>>> from sklearn.datasets import load_digits
>>> digits = load_digits()

查看第一个图像数据:

>>> digits.images[0]
array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
       [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
       [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
       [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
       [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
       [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
       [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
       [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])

我们可以用 matplotlib 将该图像画出来:

>>> import matplotlib.pyplot as plt
>>> plt.imshow(digits.images[0])
>>> plt.show()

画出来的图像如下,代表 0

2,sklearn 对 KNN 算法的实现

sklearn 库的 neighbors 模块实现了KNN 相关算法,其中:

这两个类的构造方法基本一致,这里我们主要介绍 KNeighborsClassifier 类,原型如下:

KNeighborsClassifier(
    n_neighbors=5, 
    weights='uniform', 
    algorithm='auto', 
    leaf_size=30, 
    p=2, 
    metric='minkowski', 
    metric_params=None, 
    n_jobs=None, 
    **kwargs)

来看下几个重要参数的含义:

3,构造 KNN 分类器

首先加载数据集:

from sklearn.datasets import load_digits

digits = load_digits()
data = digits.data     # 特征集
target = digits.target # 目标集

将数据集拆分为训练集(75%)和测试集(25%):

from sklearn.model_selection import train_test_split

train_x, test_x, train_y, test_y = train_test_split(
    data, target, test_size=0.25, random_state=33)

构造KNN 分类器:

from sklearn.neighbors import KNeighborsClassifier

# 采用默认参数
knn = KNeighborsClassifier() 

拟合模型:

knn.fit(train_x, train_y) 

预测数据:

predict_y = knn.predict(test_x) 

计算模型准确度:

from sklearn.metrics import accuracy_score

score = accuracy_score(test_y, predict_y)
print score # 0.98

最终计算出来模型的准确度是 98%,准确度还是不错的。

4,总结

本篇文章使用KNN 算法处理了一个实际的分类问题,主要介绍了以下几点:

(本节完。)


推荐阅读:

KNN 算法-理论篇-如何给电影进行分类

决策树算法-理论篇-如何计算信息纯度

决策树算法-实战篇-鸢尾花及波士顿房价预测

朴素贝叶斯分类-理论篇-如何通过概率解决分类问题

朴素贝叶斯分类-实战篇-如何进行文本分类

上一篇 下一篇

猜你喜欢

热点阅读