KNN算法应用

2020-08-16  本文已影响0人  大板锹

KNN算法原理

KNN的全称是K Nearest Neighbors,意思是K个最近的邻居,从这个名字我们就能看出一些KNN算法的蛛丝马迹了。K个最近邻居,毫无疑问,K的取值肯定是至关重要的。那么最近的邻居又是怎么回事呢?其实啊,KNN的原理就是当预测一个新的值x的时候,根据它距离最近的K个点是什么类别来判断x属于哪个类别。

KNN算法实现步骤

  1. 处理数据
  2. 数据向量化
  3. 计算欧几里得距离
  4. 根据距离进行分类

KNN算法实例(手写数字识别)

可应用于识别图片验证码,车牌号照片识别等场景。

导入需要的包

from numpy import *
import operator
from os import listdir

编写knn函数

#(个数,测试集,训练集,类别)
#从列方向扩展
#tile(a,(size,1))
def knn(k,testdata,traindata,labels):
    traindatasize = traindata.shape[0]
    dif = tile(testdata,(traindatasize,1))-traindata
    sqdif = dif ** 2
    sumsqdif = sqdif.sum(axis=1)
    distance = sumsqdif ** 0.5
    sortdistance = distance.argsort()#对下标排序
    count = {}
    for i in range(0,k):
        vote = labels[sortdistance[i]]
        count[vote] = count.get(vote,0) + 1
    sortcount = sorted(count.items(),key=operator.itemgetter(1),reverse=True)
    return sortcount[0][0]

对手写图片进行处理,将图片编为txt文件,空白地方用‘0’表示,数字地方用‘1’表示

from PIL import Image
im=Image.open("D:\DataFrog\shujuwajue\suanfa/a.jpg")
fh=open("D:\DataFrog\shujuwajue\suanfa/a.txt","a")
width=im.size[0]
height=im.size[1]
for i in range(0,width):
    for j in range(0,height):
        cl=im.getpixel((i,j))
        clall=cl[0]+cl[1]+cl[2]
        if(clall==0):
            #黑色
            fh.write("1")
        else:
            fh.write("0")
    fh.write("\n")
fh.close()
手写图片.png
转为文本.png

加载数据(32是处理图片时按照32*32像素处理形成的矩阵大小)

def datatoarray(fname):
    arr=[]
    fh=open(fname)
    for i in range(0,32):
        thisline=fh.readline()
        for j in range(0,32):
            arr.append(int(thisline[j]))
    return arr

建立一个函数取文件名前缀

def traindata():
    labels=[]
    trainfile=listdir("D:\DataFrog\shujuwajue\suanfa/traindata")
    num=len(trainfile)
    #长度1024(列),每一行存储一个文件
    #用一个数组存储所有训练数据,行:文件总数,列:1024
    trainarr=zeros((num,1024))
    for i in range(0,num):
        thisfname=trainfile[i]
        thislabel=seplabel(thisfname)
        labels.append(thislabel)
        trainarr[i,:]=datatoarray("traindata/"+thisfname)
    return trainarr,labels

用测试数据调用KNN算法去测试,看是否能够准确识别

def datatest():
    trainarr,labels=traindata()
    testlist=listdir("D:\DataFrog\shujuwajue\suanfa/testdata")
    tnum=len(testlist)
    for i in range(0,tnum):
        thistestfile=testlist[i]
        testarr=datatoarray("testdata/"+thistestfile)
        rknn=knn(3,testarr,trainarr,labels)
        print(rknn)

datatest()
结果.png

取单独一个文件进行测试

trainarr,labels=traindata()
thistestfile="8_76.txt"
testarr=datatoarray("testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)
print(rknn)
结果.png 文件.png
文本矩阵.png
上一篇 下一篇

猜你喜欢

热点阅读