OpenCV 中使用 SVM 训练 MNIST 手写体识别

2021-12-19  本文已影响0人  一路向后

1.源码实现

#include <iostream>
#include <string>
#include <fstream>
#include <opencv2/opencv.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/highgui/highgui.hpp>

using namespace std;
using namespace cv;

// Byte-order conversion: returns `i` with its four low-order bytes in
// reversed order (byte 0 <-> byte 3, byte 1 <-> byte 2). Used to turn the
// big-endian header fields of the MNIST files into host-order values on a
// little-endian machine.
unsigned int reverseInt(unsigned int i)
{
    const unsigned int lowest  = i & 0xffu;
    const unsigned int low_mid = (i >> 8) & 0xffu;
    const unsigned int high_mid = (i >> 16) & 0xffu;
    const unsigned int highest = (i >> 24) & 0xffu;

    // The four shifted bytes occupy disjoint bit ranges, so OR-ing them
    // together is the same as adding them.
    return (lowest << 24) | (low_mid << 16) | (high_mid << 8) | highest;
}

// Loads an MNIST image file (IDX3 layout: four big-endian 32-bit header
// fields -- magic number, image count, rows, columns -- followed by one
// unsigned byte per pixel) into a matrix with one flattened image per row.
//
// fileName: path of the image file, e.g. "./train-images-idx3-ubyte".
//
// Returns a CV_32FC1 matrix of size (image count) x (rows * columns) holding
// the raw pixel values as floats. Returns an empty Mat, after printing a
// message to stderr, when the file cannot be opened, its header is not that
// of an MNIST image file, or the pixel data is shorter than the header says.
Mat read_mnist_image(const string fileName)
{
    // Magic number of an MNIST image file (0x00000803).
    const unsigned int kImageMagic = 2051;

    ifstream file(fileName, ios::binary);
    if(!file.is_open())
    {
        cerr << "无法打开图像集: " << fileName << endl;
        return Mat();
    }

    cout << "成功打开图像集..." << endl;

    unsigned int magic_number = 0;
    unsigned int number_of_images = 0;
    unsigned int n_rows = 0;
    unsigned int n_cols = 0;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));         // magic number (file format)
    file.read(reinterpret_cast<char *>(&number_of_images), sizeof(number_of_images)); // total number of images
    file.read(reinterpret_cast<char *>(&n_rows), sizeof(n_rows));                     // rows per image
    file.read(reinterpret_cast<char *>(&n_cols), sizeof(n_cols));                     // columns per image

    if(!file)
    {
        cerr << "图像集文件头读取失败: " << fileName << endl;
        return Mat();
    }

    // Header fields are stored big-endian; swap them into host order.
    magic_number = reverseInt(magic_number);
    number_of_images = reverseInt(number_of_images);
    n_rows = reverseInt(n_rows);
    n_cols = reverseInt(n_cols);

    cout << "幻数(文件格式): " << magic_number << endl;
    cout << "图片总数: " << number_of_images << endl;
    cout << "每个图像的行数: " << n_rows << endl;
    cout << "每个图像的列数: " << n_cols << endl;

    // Mat dimensions are ints, so reject header values that do not survive
    // the conversion before they are used to size an allocation.
    const unsigned long long pixels = static_cast<unsigned long long>(n_rows) * n_cols;
    const int image_count = static_cast<int>(number_of_images);
    const int pixel_count = static_cast<int>(pixels);

    if(magic_number != kImageMagic || image_count < 0 || pixel_count <= 0 ||
       static_cast<unsigned long long>(pixel_count) != pixels)
    {
        cerr << "图像集文件头无效: " << fileName << endl;
        return Mat();
    }

    cout << "开始读取Image数据..." << endl;

    Mat DataMat = Mat::zeros(image_count, pixel_count, CV_32FC1);

    // Read one whole image per call instead of one byte per call, then let
    // OpenCV widen the bytes to float directly into the destination row.
    Mat row_bytes(1, pixel_count, CV_8UC1);

    for(int i=0; i<image_count; i++)
    {
        file.read(reinterpret_cast<char *>(row_bytes.data), pixel_count);

        if(file.gcount() != static_cast<std::streamsize>(pixel_count))
        {
            cerr << "图像集数据不完整: " << fileName << endl;
            return Mat();
        }

        // dst_row is a header onto row i of DataMat; because its size and
        // type already match, convertTo writes into DataMat in place.
        Mat dst_row = DataMat.row(i);
        row_bytes.convertTo(dst_row, CV_32FC1);
    }

    cout << "读取Image数据完毕..." << endl;

    return DataMat;
}

// Loads an MNIST label file (IDX1 layout: two big-endian 32-bit header
// fields -- magic number and item count -- followed by one unsigned byte per
// label) into a single-column matrix.
//
// fileName: path of the label file, e.g. "./train-labels-idx1-ubyte".
//
// Returns a CV_32SC1 matrix of size (item count) x 1 with one label per row.
// Returns an empty Mat, after printing a message to stderr, when the file
// cannot be opened, its header is not that of an MNIST label file, or the
// label data is shorter than the header says.
Mat read_mnist_label(const string fileName)
{
    // Magic number of an MNIST label file (0x00000801).
    const unsigned int kLabelMagic = 2049;

    ifstream file(fileName, ios::binary);
    if(!file.is_open())
    {
        cerr << "无法打开标签集: " << fileName << endl;
        return Mat();
    }

    cout << "成功打开标签集..." << endl;

    unsigned int magic_number = 0;
    unsigned int number_of_items = 0;

    file.read(reinterpret_cast<char *>(&magic_number), sizeof(magic_number));       // magic number (file format)
    file.read(reinterpret_cast<char *>(&number_of_items), sizeof(number_of_items)); // total number of labels

    if(!file)
    {
        cerr << "标签集文件头读取失败: " << fileName << endl;
        return Mat();
    }

    // Header fields are stored big-endian; swap them into host order.
    magic_number = reverseInt(magic_number);
    number_of_items = reverseInt(number_of_items);

    cout << "幻数(文件格式): " << magic_number << endl;
    cout << "标签总数: " << number_of_items << endl;

    // Mat dimensions are ints; a count that turns negative here cannot be
    // represented and indicates a bad header.
    const int item_count = static_cast<int>(number_of_items);

    if(magic_number != kLabelMagic || item_count < 0)
    {
        cerr << "标签集文件头无效: " << fileName << endl;
        return Mat();
    }

    cout << "开始读取Label数据..." << endl;

    Mat LabelMat = Mat::zeros(item_count, 1, CV_32SC1);

    if(item_count > 0)
    {
        // Read every label byte in one call, then widen to 32-bit signed
        // integers, which is the element type of a CV_32SC1 matrix.
        Mat raw(item_count, 1, CV_8UC1);

        file.read(reinterpret_cast<char *>(raw.data), item_count);

        if(file.gcount() != static_cast<std::streamsize>(item_count))
        {
            cerr << "标签集数据不完整: " << fileName << endl;
            return Mat();
        }

        // LabelMat already has the target size and type, so this fills it
        // in place.
        raw.convertTo(LabelMat, CV_32SC1);
    }

    cout << "读取Label数据完毕..." << endl;

    return LabelMat;
}

// Copies every `step`-th row of `a` within the half-open row range
// [start, end) into `b`.
//
// a:     source matrix; must be CV_32FC1 (an empty matrix is accepted).
// b:     output; replaced by a CV_32FC1 matrix with one row per selected
//        source row and a.cols columns. It has 0 rows when the range is
//        empty or the arguments are invalid (step <= 0, start < 0).
// start: index of the first row to copy.
// end:   one past the last candidate row; values beyond a.rows are clamped.
// step:  distance between selected rows; must be positive.
//
// Every row start, start+step, ... that is < end is copied, including the
// last one when (end - start) is not a multiple of step.
void get_rows_sub_mat_a(Mat &a, Mat &b, int start, int end, int step)
{
    CV_Assert(a.empty() || a.type() == CV_32FC1);

    // Never read past the last row of the source.
    if(end > a.rows)
    {
        end = a.rows;
    }

    if(step <= 0 || start < 0 || start >= end)
    {
        b = Mat::zeros(0, a.cols, CV_32FC1);
        return;
    }

    // Ceiling of (end - start) / step, written so that it cannot overflow:
    // end > start here, so (end - start - 1) is non-negative.
    const int rownum = (end - start - 1) / step + 1;

    // Build into a local matrix and assign at the end so the call also works
    // when `a` and `b` refer to the same Mat object.
    Mat result = Mat::zeros(rownum, a.cols, CV_32FC1);

    for(int i=0; i<rownum; i++)
    {
        // dst_row is a header onto row i of result; size and type already
        // match, so copyTo writes into result in place.
        Mat dst_row = result.row(i);
        a.row(start + step*i).copyTo(dst_row);
    }

    b = result;
}

// Copies column 0 of every `step`-th row of the integer label matrix `a`
// within the half-open row range [start, end) into `b`, converting each
// label to float.
//
// a:     source label matrix; must be CV_32SC1 (an empty matrix is accepted).
// b:     output; replaced by a CV_32FC1 matrix with one row per selected
//        source row and a.cols columns, of which only column 0 is written.
//        It has 0 rows when the range is empty or the arguments are invalid
//        (step <= 0, start < 0).
// start: index of the first row to copy.
// end:   one past the last candidate row; values beyond a.rows are clamped.
// step:  distance between selected rows; must be positive.
//
// Every row start, start+step, ... that is < end is copied, including the
// last one when (end - start) is not a multiple of step.
void get_rows_sub_mat_b(Mat &a, Mat &b, int start, int end, int step)
{
    CV_Assert(a.empty() || a.type() == CV_32SC1);

    // Never read past the last row of the source.
    if(end > a.rows)
    {
        end = a.rows;
    }

    if(step <= 0 || start < 0 || start >= end)
    {
        b = Mat::zeros(0, a.cols, CV_32FC1);
        return;
    }

    // Ceiling of (end - start) / step, written so that it cannot overflow:
    // end > start here, so (end - start - 1) is non-negative.
    const int rownum = (end - start - 1) / step + 1;

    // Build into a local matrix and assign at the end so the call also works
    // when `a` and `b` refer to the same Mat object.
    Mat result = Mat::zeros(rownum, a.cols, CV_32FC1);

    for(int i=0; i<rownum; i++)
    {
        // CV_32SC1 elements are 32-bit signed ints, so read them as int.
        const int label = a.at<int>(start + step*i, 0);

        result.at<float>(i, 0) = static_cast<float>(label);
    }

    b = result;
}

int main()
{
    CvSVMParams params;
    CvSVM SVM;
    string train_images_path = "./train-images-idx3-ubyte";
    string train_labels_path = "./train-labels-idx1-ubyte";
    string test_images_path = "./t10k-images-idx3-ubyte";
    string test_labels_path = "./t10k-labels-idx1-ubyte";

    //set up SVM's parameters
    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::POLY;
    params.gamma = 1.0;
    params.C = 10.0;
    params.nu = 0.5;
    params.degree = 2.10;
    //params.coef0 = 1000.0;
    //params.term_crit = cvTermCriteria(CV_TERMCRIT_EPS, 10000, FLT_EPSILON);
    //params.svm_type = CvSVM::C_SVC;
    //params.kernel_type = CvSVM::LINEAR;
    params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 10000, 1e-6);

    //读取标签数据集
    Mat train_labels = read_mnist_label(train_labels_path);
    Mat test_labels = read_mnist_label(test_labels_path);

    //读取图像数据集
    Mat train_images = read_mnist_image(train_images_path);
    Mat test_images = read_mnist_image(test_images_path);

    Mat train_labels_subs;
    Mat train_images_subs;
    Mat test_labels_subs;
    Mat test_images_subs;

    get_rows_sub_mat_b(train_labels, train_labels_subs, 0, 60000, 1);
    get_rows_sub_mat_a(train_images, train_images_subs, 0, 60000, 1);

    get_rows_sub_mat_b(test_labels, test_labels_subs, 0, 10000, 1);
    get_rows_sub_mat_a(test_images, test_images_subs, 0, 10000, 1);

    SVM.train(train_images_subs, train_labels_subs, Mat(), Mat(), params);

    SVM.save("mnist_svm.xml");

    //cout << "train end" << endl;

    int count = 0;
    for(int i = 0; i < test_images_subs.rows; i++)
    {
        Mat sample = test_images_subs.row(i);
        Mat label;
        float res = SVM.predict(sample);
        int r = 0;

        //cout << "res: " << res << " label: " << test_labels_subs.at<float>(i, 0) << endl;

        r = std::abs(res - test_labels_subs.at<float>(i, 0)) <= 0.0001 ? 1 : 0;

        count += r;
    }

    cout << "正确的识别个数 count = " << count << endl;
    cout << "错误率为..." << double(test_images_subs.rows - count) / test_images_subs.rows * 100.0 << "%....\n";

    return 0;
}

2.编译源码

$ g++ -o test test.cpp -std=c++11 -I/usr/local/include/opencv4 -L/usr/local/lib -lopencv_core -lopencv_highgui -lopencv_imgproc -lopencv_ml -Wl,-rpath=/usr/local/lib

3.运行及其结果

$ time ./test
成功打开标签集...
幻数(文件格式): 2049
标签总数: 60000
开始读取Label数据...
读取Label数据完毕...
成功打开标签集...
幻数(文件格式): 2049
标签总数: 10000
开始读取Label数据...
读取Label数据完毕...
成功打开图像集...
幻数(文件格式): 2051
图片总数: 60000
每个图像的行数: 28
每个图像的列数: 28
开始读取Image数据...
读取Image数据完毕...
成功打开图像集...
幻数(文件格式): 2051
图片总数: 10000
每个图像的行数: 28
每个图像的列数: 28
开始读取Image数据...
读取Image数据完毕...
正确的识别个数 count = 9807
错误率为...1.93%....

real    3m30.608s
user    3m29.861s
sys     0m0.164s
上一篇下一篇

猜你喜欢

热点阅读