嵌牛IT观察

k-近邻算法构建手写识别系统

2017-10-17  本文已影响0人  标准与或式

姓名:刘强
【嵌牛导读】
手写识别是计算机视觉的一个研究方向,可以看成是一个分类问题。机器学习的任务,便是解决分类(有监督学习)、聚类(无监督学习)和回归(强化学习)问题。k-近邻算法(简称kNN)是最简单的有监督学习算法,本文介绍了如何用k-近邻算法构建一个手写识别系统,并附上其python实现。
【嵌牛鼻子】
k-近邻算法 机器学习 分类 手写识别
【嵌牛提问】
k-近邻算法是什么? 如何构建一个手写识别系统?
【嵌牛正文】

k近邻算法基本思想

存在一个样本数据集,称为训练集,训练集中每个数据都存在标签(标签即数据所属的类别,从这一点可以看出,k近邻算法属于有监督学习)。对于不知道标签的新数据,将新数据的每个特征与训练集中数据对应的特征相比较,选出训练集中前k个最相似的数据(这就是k-近邻算法名称中k的出处),然后对这k个数据做统计,选择出现次数最多的标签作为新数据的标签(即k-近邻算法的输出)。
从其基本思想可以看出,k-近邻算法用于解决分类问题。所谓近邻,其实是用数据之间的欧氏距离来衡量它们的相似程度,距离越短,表示两个数据越相似。

图片来源于知乎

构建手写识别系统

需求分析

很多输入法都支持手写输入,实现手写输入通常的做法是把手写的结果生成图片,进行图像识别。我们知道,图片可以用矩阵表示,对于单通道的灰度图像,假如分辨率为32X32,则可以用一个32X32的矩阵表示,矩阵中的每个元素表示图片中该位置的像素,元素的值为0~255之间的灰度值。
而对于手写图片,表示方法则更加简单,因为手写图片是只有黑白两色的二值图像,利用图像处理软件,黑色的位置写1,白色背景写0,将其转成文本文件,如下图所示:

手写图片转成的文本文件

虽然这样表示不能有效利用内存空间(本来0/1只需占据1bit的空间,但是变成字符“0”,“1”之后需要用char类型所占的字节数),但是对于图像到矩阵的转换这一过程非常直观,方便演示。
我们的目标是:将这样的一幅“图像”输入我们的系统,我们能够输出“图像”中所显示的数字(只做数字0~9的识别)。

系统组成

我们的手写识别系统由以下部分组成:

已知标签的训练集

点此下载:用到的数据及源代码
其中,trainingDigits文件夹中存放的是用作训练集的的图片,其中包含了1934个训练样本,testDigits文件夹中存放的是用作测试集的图片,其中包含了946个测试样本。每个文件的文件名中含有它的标签。

文件输入输出模块

python读文本文件相当简单,为了迎合后续的kNN算法,我们不把图像表示成32X32的矩阵形式,而是将其转化成1X1024的向量,为此我们定义一个img2vector函数:

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect
kNN算法模块

根据上述对kNN算法的描述,kNN算法有如下步骤:

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
系统整体代码
'''
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)
            
Output:     the most popular class label
'''
from numpy import *
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
    
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))

系统测试

测试环境
测试步骤
测试结果
测试结果

从测试结果来看,1.0571%的错误率,准确度还是蛮高的……

增加训练集的样本容量能有效提高系统的准确度,但是同时增加了运算量,使计算耗时增加。

上一篇下一篇

猜你喜欢

热点阅读