opencv中SVM训练mnist手写体
2021-12-19 本文已影响0人
一路向后
1.源码实现
#include <iostream>
#include <string>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/highgui/highgui.hpp>
using namespace std;
using namespace cv;
// Byte-order conversion helper.
// Reverses the order of the four bytes of a 32-bit value, e.g.
// 0x01020304 becomes 0x04030201. The loaders below apply it to every
// header field right after reading it from an MNIST file.
unsigned int reverseInt(unsigned int i)
{
    const unsigned int from_byte0 = (i & 0x000000ffu) << 24; // lowest byte -> highest
    const unsigned int from_byte1 = (i & 0x0000ff00u) << 8;
    const unsigned int from_byte2 = (i & 0x00ff0000u) >> 8;
    const unsigned int from_byte3 = (i >> 24) & 0xffu;       // highest byte -> lowest
    return from_byte0 | from_byte1 | from_byte2 | from_byte3;
}
/**
 * Loads an MNIST image file into a matrix with one image per row.
 *
 * The file starts with four 32-bit header fields (magic number, image count,
 * rows per image, columns per image), each byte-swapped with reverseInt()
 * after reading, followed by one unsigned byte per pixel.
 *
 * @param fileName Path of the image file to read.
 * @return A CV_32FC1 matrix of size (image count) x (rows * cols) holding the
 *         pixel bytes converted to float. An empty Mat is returned when the
 *         file cannot be opened, the header is incomplete or implausible, or
 *         the pixel data is shorter than the header announces; the reason is
 *         written to stderr.
 */
Mat read_mnist_image(const string fileName)
{
    unsigned int magic_number = 0;
    unsigned int number_of_images = 0;
    unsigned int n_rows = 0;
    unsigned int n_cols = 0;
    Mat DataMat;

    ifstream file(fileName, ios::binary);
    if(!file.is_open())
    {
        cerr << "read_mnist_image: cannot open '" << fileName << "'" << endl;
        return DataMat;
    }
    cout << "成功打开图像集..." << endl;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));         // magic number (file format)
    file.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images)); // number of images
    file.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));                     // rows per image
    file.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));                     // columns per image
    if(!file)
    {
        // Without this check a short file would leave the fields at 0 or
        // partially filled and the sizes below would be meaningless.
        cerr << "read_mnist_image: incomplete header in '" << fileName << "'" << endl;
        return DataMat;
    }

    magic_number = reverseInt(magic_number);
    number_of_images = reverseInt(number_of_images);
    n_rows = reverseInt(n_rows);
    n_cols = reverseInt(n_cols);
    cout << "幻数(文件格式): " << magic_number << endl;
    cout << "图片总数: " << number_of_images << endl;
    cout << "每个图像的行数: " << n_rows << endl;
    cout << "每个图像的列数: " << n_cols << endl;

    // cv::Mat dimensions are signed ints, so reject header values that do not
    // fit instead of letting them wrap around during the conversion.
    const unsigned long long max_dim = 0x7fffffffULL;
    const unsigned long long pixels_per_image = static_cast<unsigned long long>(n_rows) * n_cols;
    if(number_of_images > max_dim || pixels_per_image > max_dim)
    {
        cerr << "read_mnist_image: implausible sizes in header of '" << fileName << "'" << endl;
        return DataMat;
    }
    const int image_count = static_cast<int>(number_of_images);
    const int image_size = static_cast<int>(pixels_per_image);

    cout << "开始读取Image数据..." << endl;
    if(image_count == 0 || image_size == 0)
    {
        // Nothing to read; keep the (empty) shape announced by the header.
        DataMat = Mat::zeros(image_count, image_size, CV_32FC1);
    }
    else
    {
        // Read every pixel byte in a single call into an 8-bit matrix, then
        // let OpenCV widen the values to float.
        Mat raw(image_count, image_size, CV_8UC1);
        const std::streamsize byte_count = static_cast<std::streamsize>(raw.total());
        file.read(reinterpret_cast<char *>(raw.data), byte_count);
        if(file.gcount() != byte_count)
        {
            cerr << "read_mnist_image: pixel data truncated in '" << fileName << "'" << endl;
            return Mat();
        }
        raw.convertTo(DataMat, CV_32FC1);
    }
    cout << "读取Image数据完毕..." << endl;

    // The ifstream closes itself when it goes out of scope.
    return DataMat;
}
/**
 * Loads an MNIST label file into a single-column matrix.
 *
 * The file starts with two 32-bit header fields (magic number, item count),
 * each byte-swapped with reverseInt() after reading, followed by one unsigned
 * byte per label.
 *
 * @param fileName Path of the label file to read.
 * @return A CV_32SC1 matrix of size (item count) x 1 holding the label bytes
 *         widened to int. An empty Mat is returned when the file cannot be
 *         opened, the header is incomplete or implausible, or the label data
 *         is shorter than the header announces; the reason is written to
 *         stderr.
 */
Mat read_mnist_label(const string fileName)
{
    unsigned int magic_number = 0;
    unsigned int number_of_items = 0;
    Mat LabelMat;

    ifstream file(fileName, ios::binary);
    if(!file.is_open())
    {
        cerr << "read_mnist_label: cannot open '" << fileName << "'" << endl;
        return LabelMat;
    }
    cout << "成功打开标签集..." << endl;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));       // magic number (file format)
    file.read(reinterpret_cast<char *>(&number_of_items), sizeof(number_of_items)); // number of labels
    if(!file)
    {
        cerr << "read_mnist_label: incomplete header in '" << fileName << "'" << endl;
        return LabelMat;
    }

    magic_number = reverseInt(magic_number);
    number_of_items = reverseInt(number_of_items);
    cout << "幻数(文件格式): " << magic_number << endl;
    cout << "标签总数: " << number_of_items << endl;

    // cv::Mat dimensions are signed ints; refuse counts that do not fit.
    if(number_of_items > 0x7fffffffu)
    {
        cerr << "read_mnist_label: implausible item count in header of '" << fileName << "'" << endl;
        return LabelMat;
    }
    const int item_count = static_cast<int>(number_of_items);

    cout << "开始读取Label数据..." << endl;
    if(item_count == 0)
    {
        LabelMat = Mat::zeros(item_count, 1, CV_32SC1);
    }
    else
    {
        // Read all label bytes at once, then widen them to 32-bit signed
        // integers, which is the element type CV_32SC1 actually stores.
        Mat raw(item_count, 1, CV_8UC1);
        const std::streamsize byte_count = static_cast<std::streamsize>(raw.total());
        file.read(reinterpret_cast<char *>(raw.data), byte_count);
        if(file.gcount() != byte_count)
        {
            cerr << "read_mnist_label: label data truncated in '" << fileName << "'" << endl;
            return Mat();
        }
        raw.convertTo(LabelMat, CV_32SC1);
    }
    cout << "读取Label数据完毕..." << endl;

    // The ifstream closes itself when it goes out of scope.
    return LabelMat;
}
/**
 * Copies every `step`-th row of a float matrix into a new matrix.
 *
 * Row i of the result is row (start + step * i) of `a`, for
 * i in [0, (end - start) / step).
 *
 * @param a     Source matrix; must be CV_32FC1 when at least one row is copied.
 * @param b     Receives a freshly allocated CV_32FC1 matrix with
 *              (end - start) / step rows and a.cols columns.
 * @param start Index of the first source row.
 * @param end   End of the row range (exclusive in the direction of `step`).
 * @param step  Row stride; must not be 0.
 *
 * Invalid arguments (zero step, negative row count, rows outside `a`, wrong
 * element type) raise a cv::Exception through CV_Assert instead of reading
 * out of bounds.
 */
void get_rows_sub_mat_a(Mat &a, Mat &b, int start, int end, int step)
{
    CV_Assert(step != 0); // the row count below divides by step
    const int rownum = (end - start) / step;
    CV_Assert(rownum >= 0);

    // Fill a local matrix first so that the source is still intact while it
    // is being read even if the caller passes the same Mat for a and b.
    Mat out = Mat::zeros(rownum, a.cols, CV_32FC1);
    if(rownum > 0)
    {
        // Source rows change monotonically with i, so checking the first and
        // last ones covers every row that will be touched.
        const int first_row = start;
        const int last_row = start + step * (rownum - 1);
        CV_Assert(first_row >= 0 && first_row < a.rows);
        CV_Assert(last_row >= 0 && last_row < a.rows);
        CV_Assert(a.type() == CV_32FC1);

        for(int i = 0; i < rownum; i++)
        {
            const float *src = a.ptr<float>(start + step * i);
            float *dst = out.ptr<float>(i);
            for(int j = 0; j < a.cols; j++)
            {
                dst[j] = src[j];
            }
        }
    }
    b = out;
}
/**
 * Extracts every `step`-th label from an integer label matrix as floats.
 *
 * Element (i, 0) of the result is element (start + step * i, 0) of `a`
 * converted to float, for i in [0, (end - start) / step). Only column 0 of
 * the result is written; any further columns stay 0.
 *
 * @param a     Source labels; must be CV_32SC1 with at least one column when
 *              at least one row is copied.
 * @param b     Receives a freshly allocated CV_32FC1 matrix with
 *              (end - start) / step rows and a.cols columns.
 * @param start Index of the first source row.
 * @param end   End of the row range (exclusive in the direction of `step`).
 * @param step  Row stride; must not be 0.
 *
 * Invalid arguments (zero step, negative row count, rows outside `a`, wrong
 * element type) raise a cv::Exception through CV_Assert instead of reading
 * out of bounds.
 */
void get_rows_sub_mat_b(Mat &a, Mat &b, int start, int end, int step)
{
    CV_Assert(step != 0); // the row count below divides by step
    const int rownum = (end - start) / step;
    CV_Assert(rownum >= 0);

    // Fill a local matrix first so that the source is still intact while it
    // is being read even if the caller passes the same Mat for a and b.
    Mat out = Mat::zeros(rownum, a.cols, CV_32FC1);
    if(rownum > 0)
    {
        // Source rows change monotonically with i, so checking the first and
        // last ones covers every row that will be touched.
        const int first_row = start;
        const int last_row = start + step * (rownum - 1);
        CV_Assert(first_row >= 0 && first_row < a.rows);
        CV_Assert(last_row >= 0 && last_row < a.rows);
        CV_Assert(a.type() == CV_32SC1 && a.cols > 0);

        for(int i = 0; i < rownum; i++)
        {
            // Access the element as int, the type a CV_32SC1 matrix stores.
            const int label = a.at<int>(start + step * i, 0);
            out.at<float>(i, 0) = static_cast<float>(label);
        }
    }
    b = out;
}
int main()
{
CvSVMParams params;
CvSVM SVM;
string train_images_path = "./train-images-idx3-ubyte";
string train_labels_path = "./train-labels-idx1-ubyte";
string test_images_path = "./t10k-images-idx3-ubyte";
string test_labels_path = "./t10k-labels-idx1-ubyte";
//set up SVM's parameters
params.svm_type = CvSVM::C_SVC;
params.kernel_type = CvSVM::POLY;
params.gamma = 1.0;
params.C = 10.0;
params.nu = 0.5;
params.degree = 2.10;
//params.coef0 = 1000.0;
//params.term_crit = cvTermCriteria(CV_TERMCRIT_EPS, 10000, FLT_EPSILON);
//params.svm_type = CvSVM::C_SVC;
//params.kernel_type = CvSVM::LINEAR;
params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 10000, 1e-6);
//读取标签数据集
Mat train_labels = read_mnist_label(train_labels_path);
Mat test_labels = read_mnist_label(test_labels_path);
//读取图像数据集
Mat train_images = read_mnist_image(train_images_path);
Mat test_images = read_mnist_image(test_images_path);
Mat train_labels_subs;
Mat train_images_subs;
Mat test_labels_subs;
Mat test_images_subs;
get_rows_sub_mat_b(train_labels, train_labels_subs, 0, 60000, 1);
get_rows_sub_mat_a(train_images, train_images_subs, 0, 60000, 1);
get_rows_sub_mat_b(test_labels, test_labels_subs, 0, 10000, 1);
get_rows_sub_mat_a(test_images, test_images_subs, 0, 10000, 1);
SVM.train(train_images_subs, train_labels_subs, Mat(), Mat(), params);
SVM.save("mnist_svm.xml");
//cout << "train end" << endl;
int count = 0;
for(int i = 0; i < test_images_subs.rows; i++)
{
Mat sample = test_images_subs.row(i);
Mat label;
float res = SVM.predict(sample);
int r = 0;
//cout << "res: " << res << " label: " << test_labels_subs.at<float>(i, 0) << endl;
r = std::abs(res - test_labels_subs.at<float>(i, 0)) <= 0.0001 ? 1 : 0;
count += r;
}
cout << "正确的识别个数 count = " << count << endl;
cout << "错误率为..." << double(test_images_subs.rows - count) / test_images_subs.rows * 100.0 << "%....\n";
return 0;
}
2.编译源码
$ g++ -o test test.cpp -std=c++11 -I/usr/local/include/opencv4 -L/usr/local/lib -lopencv_core -lopencv_highgui -lopencv_imgproc -lopencv_ml -Wl,-rpath=/usr/local/lib
3.运行及其结果
$ time ./test
成功打开标签集...
幻数(文件格式): 2049
标签总数: 60000
开始读取Label数据...
读取Label数据完毕...
成功打开标签集...
幻数(文件格式): 2049
标签总数: 10000
开始读取Label数据...
读取Label数据完毕...
成功打开图像集...
幻数(文件格式): 2051
图片总数: 60000
每个图像的行数: 28
每个图像的列数: 28
开始读取Image数据...
读取Image数据完毕...
成功打开图像集...
幻数(文件格式): 2051
图片总数: 10000
每个图像的行数: 28
每个图像的列数: 28
开始读取Image数据...
读取Image数据完毕...
正确的识别个数 count = 9807
错误率为...1.93%....
real 3m30.608s
user 3m29.861s
sys 0m0.164s