计算机视觉---手写体识别,K-最近邻(KNN)分类
2017-09-14 本文已影响0人
sdnjyxr
参考了前辈的文章,使用MNIST库的手写体数据进行训练。可能是由于中外手写习惯不同,识别准确率不是很高,建议自己搜集数据集,例如向周围同学搜集手写素材。本文将MNIST中的digits.png划分为5000个样本进行训练。
效果图
S.1 对digits.png进行划分,得到训练的数据集
def initKnn(path='digits.png'):
    """Create an untrained k-NN model and load the digit training set.

    The image at *path* is converted to grayscale and cut into a grid of
    50 rows by 100 columns. Every cell is flattened into one float32 row of
    400 values, so the cells are expected to hold 400 pixels each
    (presumably 20x20, i.e. a 2000x1000 sheet -- confirm against the image).

    Args:
        path: Location of the digit sheet image. Defaults to 'digits.png'
            in the current working directory.

    Returns:
        A tuple ``(knn, train, trainLabel)``: the untrained
        ``cv2.ml.KNearest`` model, the sample matrix with one row per cell,
        and a label vector holding each of 0..9 repeated 500 times in a row
        (assumes the sheet stores 500 consecutive cells per digit).

    Raises:
        OSError: If the image cannot be read.
    """
    img = cv2.imread(path)  # load the data set
    if img is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an obscure cvtColor error.
        raise OSError('cannot read training image: %r' % (path,))

    knn = cv2.ml.KNearest_create()  # build the k-NN model
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # convert to grayscale
    # Split the sheet into 50 horizontal bands, each into 100 cells.
    cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]
    # One flattened cell per row.
    train = np.array(cells).reshape(-1, 400).astype(np.float32)
    trainLabel = np.repeat(np.arange(10), 500)  # labels for the samples
    return knn, train, trainLabel
S.2 对图片进行形态学处理并做边缘检测,框选数字所在区域,并显示相应的预测数值
def findRoi(frame, thresValue):
    """Locate digit candidates in a BGR frame and annotate them in place.

    The grayscale frame is closed morphologically (two dilations followed by
    two erosions) and subtracted from the original grayscale image. Sobel
    gradients of that difference are thresholded with *thresValue*, and the
    external contours of the result give the candidate regions. Each region
    wider than 10 px and taller than 20 px is classified with ``findDigit``
    and drawn onto *frame* as a rectangle with the predicted digit as text.

    Relies on the module-level ``knn`` model, which must be trained before
    this function is called.

    Args:
        frame: BGR image; it is modified in place by the drawing calls.
        thresValue: Threshold applied to the combined Sobel magnitude.

    Returns:
        The single-channel difference image between the grayscale frame and
        its morphological closing.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    closed = cv2.dilate(gray, None, iterations=2)  # dilate twice
    closed = cv2.erode(closed, None, iterations=2)  # erode twice
    edges = cv2.absdiff(gray, closed)  # difference image

    # Sobel edge detection on the difference image, both directions blended.
    gradX = cv2.Sobel(edges, cv2.CV_16S, 1, 0)
    gradY = cv2.Sobel(edges, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(gradX)
    absY = cv2.convertScaleAbs(gradY)
    dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    _, ddst = cv2.threshold(dst, thresValue, 255, cv2.THRESH_BINARY)  # binarize

    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x
    # return (contours, hierarchy); the contours are always second to last.
    contours = cv2.findContours(
        ddst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if w <= 10 or h <= 20:
            continue  # too small to be a digit
        # The patch is binarized with a fixed threshold of 50, independent
        # of thresValue.
        digit, _ = findDigit(knn, edges[y:y+h, x:x+w], 50)
        cv2.rectangle(frame, (x, y), (x+w, y+h), (153, 153, 0), 2)  # bounding box
        cv2.putText(frame, str(digit), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (127, 0, 255), 2)
    return edges
S.3 对检测到的ROI区域进行预测,输出预测值
def findDigit(knn, roi, thresValue):
    """Predict the digit shown in an image patch with a k-NN model.

    The patch is binarized at *thresValue*, resized to 20x20, flattened to
    float32 rows of 400 values and passed to ``knn.findNearest`` with k=5.

    Args:
        knn: Model exposing ``findNearest(samples, k=...)``.
        roi: Image patch to classify.
        thresValue: Binarization threshold.

    Returns:
        A tuple ``(digit, th)``: the predicted label as an int and the
        binarized, resized 20x20 patch.
    """
    binary = cv2.threshold(roi, thresValue, 255, cv2.THRESH_BINARY)[1]
    th = cv2.resize(binary, (20, 20))
    sample = th.astype(np.float32).reshape(-1, 400)
    prediction = knn.findNearest(sample, k=5)[1]
    return int(prediction[0][0]), th
完整代码如下
#!/usr/bin/python3
# -*- coding: UTF-8 -*-
import cv2
import numpy as np
def initKnn(path='digits.png'):
    """Create an untrained k-NN model and load the digit training set.

    The image at *path* is converted to grayscale and cut into a grid of
    50 rows by 100 columns. Every cell is flattened into one float32 row of
    400 values, so the cells are expected to hold 400 pixels each
    (presumably 20x20, i.e. a 2000x1000 sheet -- confirm against the image).

    Args:
        path: Location of the digit sheet image. Defaults to 'digits.png'
            in the current working directory.

    Returns:
        A tuple ``(knn, train, trainLabel)``: the untrained
        ``cv2.ml.KNearest`` model, the sample matrix with one row per cell,
        and a label vector holding each of 0..9 repeated 500 times in a row
        (assumes the sheet stores 500 consecutive cells per digit).

    Raises:
        OSError: If the image cannot be read.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an obscure cvtColor error.
        raise OSError('cannot read training image: %r' % (path,))

    knn = cv2.ml.KNearest_create()
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Split the sheet into 50 horizontal bands, each into 100 cells.
    cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]
    # One flattened cell per row.
    train = np.array(cells).reshape(-1, 400).astype(np.float32)
    trainLabel = np.repeat(np.arange(10), 500)
    return knn, train, trainLabel
def findRoi(frame, thresValue):
    """Locate digit candidates in a BGR frame and annotate them in place.

    The grayscale frame is closed morphologically (two dilations followed by
    two erosions) and subtracted from the original grayscale image. Sobel
    gradients of that difference are thresholded with *thresValue*, and the
    external contours of the result give the candidate regions. Each region
    wider than 10 px and taller than 20 px is classified with ``findDigit``
    and drawn onto *frame* as a rectangle with the predicted digit as text.

    Relies on the module-level ``knn`` model, which must be trained before
    this function is called.

    Args:
        frame: BGR image; it is modified in place by the drawing calls.
        thresValue: Threshold applied to the combined Sobel magnitude.

    Returns:
        The single-channel difference image between the grayscale frame and
        its morphological closing.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    closed = cv2.dilate(gray, None, iterations=2)
    closed = cv2.erode(closed, None, iterations=2)
    edges = cv2.absdiff(gray, closed)

    # Sobel edge detection on the difference image, both directions blended.
    gradX = cv2.Sobel(edges, cv2.CV_16S, 1, 0)
    gradY = cv2.Sobel(edges, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(gradX)
    absY = cv2.convertScaleAbs(gradY)
    dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    _, ddst = cv2.threshold(dst, thresValue, 255, cv2.THRESH_BINARY)

    # OpenCV 3.x returns (image, contours, hierarchy) while 2.x and 4.x
    # return (contours, hierarchy); the contours are always second to last.
    contours = cv2.findContours(
        ddst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if w <= 10 or h <= 20:
            continue  # too small to be a digit
        # The patch is binarized with a fixed threshold of 50, independent
        # of thresValue.
        digit, _ = findDigit(knn, edges[y:y+h, x:x+w], 50)
        cv2.rectangle(frame, (x, y), (x+w, y+h), (153, 153, 0), 2)
        cv2.putText(frame, str(digit), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (127, 0, 255), 2)
    return edges
def findDigit(knn, roi, thresValue):
    """Predict the digit shown in an image patch with a k-NN model.

    The patch is binarized at *thresValue*, resized to 20x20, flattened to
    float32 rows of 400 values and passed to ``knn.findNearest`` with k=5.

    Args:
        knn: Model exposing ``findNearest(samples, k=...)``.
        roi: Image patch to classify.
        thresValue: Binarization threshold.

    Returns:
        A tuple ``(digit, th)``: the predicted label as an int and the
        binarized, resized 20x20 patch.
    """
    binary = cv2.threshold(roi, thresValue, 255, cv2.THRESH_BINARY)[1]
    th = cv2.resize(binary, (20, 20))
    sample = th.astype(np.float32).reshape(-1, 400)
    prediction = knn.findNearest(sample, k=5)[1]
    return int(prediction[0][0]), th
# Build and train the k-NN classifier on the digit sheet.
knn, train, trainLabel = initKnn()
knn.train(train, cv2.ml.ROW_SAMPLE, trainLabel)

# Capture from the default camera and record the annotated output.
cap = cv2.VideoCapture(0)
# Each recorded frame is the 426-px-wide crop next to its edge image.
# NOTE(review): the writer size assumes camera frames are 480 px high and at
# least 426 px wide -- confirm for the camera in use.
width = 426*2
height = 480
videoFrame = cv2.VideoWriter('frame.avi', cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 25, (int(width), int(height)), True)
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame was grabbed; stop instead of crashing on the slice
            # below.
            break
        frame = frame[:, :426]
        edges = findRoi(frame, 50)  # annotates `frame` in place
        newEdges = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
        newFrame = np.hstack((frame, newEdges))
        cv2.imshow('frame', newFrame)
        videoFrame.write(newFrame)
        key = cv2.waitKey(1) & 0xff
        if key == ord(' '):  # space bar quits
            break
finally:
    # Release the camera, finalize the video file and close the window even
    # when the loop exits through an exception.
    cap.release()
    videoFrame.release()
    cv2.destroyAllWindows()
参考原文:http://blog.csdn.net/littlethunder/article/details/51615237