《机器学习实战》kNN
2017-10-21 本文已影响9人
zqyadam
k-邻近算法
基本样例
# 引入基本模块
from numpy import *
import operator
# 定义数据
def createDataSet():
group = array([[1.,1.1],[1.,1.],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
# 创建分类器
def classify0(inX, dataset, labels, k):
dataSetSize = dataset.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataset
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distance = sqDistances**0.5
# print '距离\n',sqDistances
sortedDistIndicies = distance.argsort()
# print '距离排序\n', sortedDistIndicies
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
# print '各类别数量\n', classCount.items()
sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)
# print '各类别数量排序\n', sortedClassCount
return sortedClassCount[0][0]
# 测试
group, labels = createDataSet()
finalClass = classify0([1,0.8], group, labels, 3)
print 'final class:\n', finalClass
final class:
A
约会网站示例
# 定义读取文件函数
def file2matrix(filename):
f = open(filename)
arrayLines = f.readlines()
numberOfLines = len(arrayLines)
returnMat = zeros((numberOfLines,3))
classLabelVector = []
index = 0
for line in arrayLines:
line = line.strip()
listLine = line.split('\t')
returnMat[index,:] = listLine[:3]
classLabelVector.append(int(listLine[-1]))
index += 1
f.close()
return returnMat, classLabelVector
dateDataMat, dateLabels = file2matrix('datingTestSet2.txt')
# 绘制散点图
import matplotlib.pyplot as plt
# 游戏时间与飞行里程关系
fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
ax1.scatter(dateDataMat[:,0],dateDataMat[:,1],15*array(dateLabels),15*array(dateLabels))
ax1.set_title('scatter1')
plt.xlabel('fly miles')
plt.ylabel('game time percent')
# 游戏时间与冰淇淋消耗公升数
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.scatter(dateDataMat[:,1],dateDataMat[:,2],15*array(dateLabels),15*array(dateLabels))
ax2.set_title('scatter2')
plt.xlabel('game time percent')
plt.ylabel('ice cream')
# 飞行里程数与冰淇淋消耗公升数
fig3 = plt.figure()
ax3 = fig3.add_subplot(111)
ax3.scatter(dateDataMat[:,0],dateDataMat[:,2],15*array(dateLabels),15*array(dateLabels))
ax3.set_title('scatter3')
plt.xlabel('fly miles')
plt.ylabel('ice cream')
plt.show()
output_8_0.png
output_8_1.png
output_8_2.png
# 归一化
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
l = dataSet.shape[0]
normDataSet = (dataSet - tile(minVals, (l,1)))/tile(ranges, (l,1))
return normDataSet
print autoNorm(dateDataMat)
[[ 0.44832535 0.39805139 0.56233353]
[ 0.15873259 0.34195467 0.98724416]
[ 0.28542943 0.06892523 0.47449629]
...,
[ 0.29115949 0.50910294 0.51079493]
[ 0.52711097 0.43665451 0.4290048 ]
[ 0.47940793 0.3768091 0.78571804]]
# 进行测试
def dateTest():
ratio = 0.1
dateDataMat, dataLabels = file2matrix('datingTestSet2.txt')
normMat = autoNorm(dateDataMat)
rows = normMat.shape[0]
numTest = int(rows*ratio)
errorCount = 0
for i in range(numTest):
classifierResult = classify0(normMat[i,:], normMat[numTest:,:], dataLabels[numTest:],3)
print "分类器返回值:%d, 真实分类:%d"%(classifierResult, dataLabels[i])
if classifierResult != dateLabels[i]: errorCount +=1
print '错误率: ', errorCount / float(numTest)
dateTest()
分类器返回值:3, 真实分类:3
分类器返回值:2, 真实分类:2
分类器返回值:1, 真实分类:1
分类器返回值:1, 真实分类:1
分类器返回值:1, 真实分类:1
分类器返回值:1, 真实分类:1
分类器返回值:3, 真实分类:3
分类器返回值:3, 真实分类:3
(...略)
分类器返回值:3, 真实分类:3
分类器返回值:2, 真实分类:2
分类器返回值:1, 真实分类:1
分类器返回值:3, 真实分类:1
错误率: 0.05
手写识别系统
# 读入文件并转化为1×1024的数组
def img2vector(filename):
returnVec = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVec[0,32*i+j] = int(lineStr[j])
return returnVec
import os
# 读取文件集,返回输入和label
def getFileSet(filePath):
labels = []
fileList = os.listdir(filePath)
fileListLen = len(fileList)
fileSet = zeros((fileListLen,1024))
for i in range(fileListLen):
f = fileList[i]
fileStr = f.split('.')[0]
labels.append(int(fileStr.split('_')[0]))
fileSet[i,:] = img2vector(filePath +'/'+f)
return fileSet, labels, fileList
# 测试
def handwritingTest():
filePath = 'digits/testDigits'
testSet, testLabels, testFileList = getFileSet(filePath)
trainSet, trainLabels,trainFileList = getFileSet('digits/trainingDigits')
testSetLen = testSet.shape[0]
errorCount = 0
for i in range(testSetLen):
classifierResult = classify0(testSet[i,:], trainSet, trainLabels, 3)
print '%s:分类器分类结果:%d, 真实分类:%d'%(testFileList[i], classifierResult, testLabels[i])
if classifierResult != testLabels[i] : errorCount +=1
print '错误率:%f'%(errorCount /float(testSetLen))
handwritingTest()
0_0.txt:分类器分类结果:0, 真实分类:0
0_1.txt:分类器分类结果:0, 真实分类:0
0_10.txt:分类器分类结果:0, 真实分类:0
(...略)
0_81.txt:分类器分类结果:0, 真实分类:0
0_82.txt:分类器分类结果:0, 真实分类:0
0_83.txt:分类器分类结果:0, 真实分类:0
0_84.txt:分类器分类结果:0, 真实分类:0
0_85.txt:分类器分类结果:0, 真实分类:0
0_86.txt:分类器分类结果:0, 真实分类:0
0_9.txt:分类器分类结果:0, 真实分类:0
1_0.txt:分类器分类结果:1, 真实分类:1
1_1.txt:分类器分类结果:1, 真实分类:1
1_10.txt:分类器分类结果:1, 真实分类:1
1_11.txt:分类器分类结果:1, 真实分类:1
1_12.txt:分类器分类结果:1, 真实分类:1
(...略)
1_90.txt:分类器分类结果:1, 真实分类:1
1_91.txt:分类器分类结果:1, 真实分类:1
1_92.txt:分类器分类结果:1, 真实分类:1
1_93.txt:分类器分类结果:1, 真实分类:1
1_94.txt:分类器分类结果:1, 真实分类:1
1_95.txt:分类器分类结果:1, 真实分类:1
1_96.txt:分类器分类结果:1, 真实分类:1
2_0.txt:分类器分类结果:2, 真实分类:2
2_1.txt:分类器分类结果:2, 真实分类:2
2_10.txt:分类器分类结果:2, 真实分类:2
2_11.txt:分类器分类结果:2, 真实分类:2
2_12.txt:分类器分类结果:2, 真实分类:2
2_13.txt:分类器分类结果:2, 真实分类:2
(...略)
2_85.txt:分类器分类结果:2, 真实分类:2
2_86.txt:分类器分类结果:2, 真实分类:2
2_87.txt:分类器分类结果:2, 真实分类:2
2_88.txt:分类器分类结果:2, 真实分类:2
2_89.txt:分类器分类结果:2, 真实分类:2
2_9.txt:分类器分类结果:2, 真实分类:2
2_90.txt:分类器分类结果:2, 真实分类:2
2_91.txt:分类器分类结果:2, 真实分类:2
3_0.txt:分类器分类结果:3, 真实分类:3
3_1.txt:分类器分类结果:3, 真实分类:3
3_10.txt:分类器分类结果:3, 真实分类:3
3_11.txt:分类器分类结果:9, 真实分类:3
3_12.txt:分类器分类结果:3, 真实分类:3
3_13.txt:分类器分类结果:3, 真实分类:3
3_14.txt:分类器分类结果:3, 真实分类:3
(...略)
3_82.txt:分类器分类结果:3, 真实分类:3
3_83.txt:分类器分类结果:3, 真实分类:3
3_84.txt:分类器分类结果:3, 真实分类:3
3_9.txt:分类器分类结果:3, 真实分类:3
4_0.txt:分类器分类结果:4, 真实分类:4
4_1.txt:分类器分类结果:4, 真实分类:4
4_10.txt:分类器分类结果:4, 真实分类:4
4_100.txt:分类器分类结果:4, 真实分类:4
4_101.txt:分类器分类结果:4, 真实分类:4
4_102.txt:分类器分类结果:4, 真实分类:4
4_103.txt:分类器分类结果:4, 真实分类:4
4_104.txt:分类器分类结果:4, 真实分类:4
(...略)
4_95.txt:分类器分类结果:4, 真实分类:4
4_96.txt:分类器分类结果:4, 真实分类:4
4_97.txt:分类器分类结果:4, 真实分类:4
4_98.txt:分类器分类结果:4, 真实分类:4
4_99.txt:分类器分类结果:4, 真实分类:4
5_0.txt:分类器分类结果:5, 真实分类:5
5_1.txt:分类器分类结果:5, 真实分类:5
5_10.txt:分类器分类结果:5, 真实分类:5
5_100.txt:分类器分类结果:5, 真实分类:5
5_101.txt:分类器分类结果:5, 真实分类:5
(...略)
5_92.txt:分类器分类结果:5, 真实分类:5
5_93.txt:分类器分类结果:5, 真实分类:5
5_94.txt:分类器分类结果:5, 真实分类:5
5_95.txt:分类器分类结果:5, 真实分类:5
5_96.txt:分类器分类结果:5, 真实分类:5
5_97.txt:分类器分类结果:5, 真实分类:5
5_98.txt:分类器分类结果:5, 真实分类:5
5_99.txt:分类器分类结果:5, 真实分类:5
6_0.txt:分类器分类结果:6, 真实分类:6
6_1.txt:分类器分类结果:6, 真实分类:6
6_10.txt:分类器分类结果:6, 真实分类:6
6_11.txt:分类器分类结果:6, 真实分类:6
6_12.txt:分类器分类结果:6, 真实分类:6
6_13.txt:分类器分类结果:6, 真实分类:6
(...略)
6_82.txt:分类器分类结果:6, 真实分类:6
6_83.txt:分类器分类结果:6, 真实分类:6
6_84.txt:分类器分类结果:6, 真实分类:6
6_85.txt:分类器分类结果:6, 真实分类:6
6_86.txt:分类器分类结果:6, 真实分类:6
6_9.txt:分类器分类结果:6, 真实分类:6
7_0.txt:分类器分类结果:7, 真实分类:7
7_1.txt:分类器分类结果:7, 真实分类:7
7_10.txt:分类器分类结果:7, 真实分类:7
7_11.txt:分类器分类结果:7, 真实分类:7
7_12.txt:分类器分类结果:7, 真实分类:7
7_13.txt:分类器分类结果:7, 真实分类:7
7_14.txt:分类器分类结果:7, 真实分类:7
(...略)
7_91.txt:分类器分类结果:7, 真实分类:7
7_92.txt:分类器分类结果:7, 真实分类:7
7_93.txt:分类器分类结果:7, 真实分类:7
7_94.txt:分类器分类结果:7, 真实分类:7
7_95.txt:分类器分类结果:7, 真实分类:7
8_0.txt:分类器分类结果:8, 真实分类:8
8_1.txt:分类器分类结果:8, 真实分类:8
8_10.txt:分类器分类结果:8, 真实分类:8
8_11.txt:分类器分类结果:6, 真实分类:8
8_12.txt:分类器分类结果:8, 真实分类:8
8_13.txt:分类器分类结果:8, 真实分类:8
8_14.txt:分类器分类结果:8, 真实分类:8
(...略)
8_83.txt:分类器分类结果:8, 真实分类:8
8_84.txt:分类器分类结果:8, 真实分类:8
8_85.txt:分类器分类结果:8, 真实分类:8
8_86.txt:分类器分类结果:8, 真实分类:8
8_87.txt:分类器分类结果:8, 真实分类:8
8_88.txt:分类器分类结果:8, 真实分类:8
8_89.txt:分类器分类结果:8, 真实分类:8
8_9.txt:分类器分类结果:8, 真实分类:8
8_90.txt:分类器分类结果:8, 真实分类:8
9_0.txt:分类器分类结果:9, 真实分类:9
9_1.txt:分类器分类结果:9, 真实分类:9
9_10.txt:分类器分类结果:9, 真实分类:9
9_11.txt:分类器分类结果:9, 真实分类:9
9_12.txt:分类器分类结果:9, 真实分类:9
9_13.txt:分类器分类结果:9, 真实分类:9
9_14.txt:分类器分类结果:1, 真实分类:9
(...略)
9_83.txt:分类器分类结果:9, 真实分类:9
9_84.txt:分类器分类结果:9, 真实分类:9
9_85.txt:分类器分类结果:9, 真实分类:9
9_86.txt:分类器分类结果:9, 真实分类:9
9_87.txt:分类器分类结果:9, 真实分类:9
9_88.txt:分类器分类结果:9, 真实分类:9
9_9.txt:分类器分类结果:9, 真实分类:9
错误率:0.011628