Implementing a LeNet Inference Model in C++

This post builds a small LeNet inference engine by hand: a Tensor class, Conv2d / MaxPool2d / Reshape / Linear layers, a network topology read from lenet.xml, and trained weights loaded from a binary lenet.mod file.


1.tensor.h

#ifndef _CONVNET_TENSOR_H_
#define _CONVNET_TENSOR_H_

#include <vector>

namespace convnet {

    class Tensor {
    public:
        Tensor();
        Tensor(int a);
        Tensor(int a, int b);
        Tensor(int a, int b, int c);
        ~Tensor();

        void resize(int a);
        void resize(int a, int b);
        void resize(int a, int b, int c);

        void relu();
        void sigmoid();
        void argmax(int &s);

        void set(std::vector<double> &data);
        int size();

    private:
        friend class Linear;
        friend class Conv2d;
        friend class MaxPool2d;
        friend class Reshape;

        std::vector<int> dim;
        std::vector<double> data;
    };

}

#endif
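
All of the layers below are friends of Tensor and index the flat data vector directly, so the layout is worth stating once: a tensor resized as (c, h, w) stores element (i, j, k) at

data[i*h*w + j*w + k]

and every forward() below uses exactly this indexing.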

2.tensor.cpp

#include <math.h>
#include <iostream>
#include "tensor.h"

using namespace std;

convnet::Tensor::Tensor()
{
}

convnet::Tensor::Tensor(int a)
{
    dim.resize(1);

    dim[0] = a;

    if(dim[0] > 0)
    {
        data.resize(a);
    }
}

convnet::Tensor::Tensor(int a, int b)
{
    dim.resize(2);

    dim[0] = a;
    dim[1] = b;

    if(a*b > 0)
    {
        data.resize(a*b);
    }
}

convnet::Tensor::Tensor(int a, int b, int c)
{
    dim.resize(3);

    dim[0] = a;
    dim[1] = b;
    dim[2] = c;

    if(a*b*c > 0)
    {
        data.resize(a*b*c);
    }
}

void convnet::Tensor::resize(int a)
{
    dim.resize(1);

    dim[0] = a;

    if(a > 0 && a != (int)data.size())
    {
        data.resize(a);
    }
}

void convnet::Tensor::resize(int a, int b)
{
    dim.resize(2);

    dim[0] = a;
    dim[1] = b;

    if(a*b > 0 && a*b != (int)data.size())
    {
        data.resize(a*b);
    }
}

void convnet::Tensor::resize(int a, int b, int c)
{
    dim.resize(3);

    dim[0] = a;
    dim[1] = b;
    dim[2] = c;

    if(a*b*c > 0 && a*b*c != (int)data.size())
    {
        data.resize(a*b*c);
    }
}

convnet::Tensor::~Tensor()
{
    dim.clear();
}

void convnet::Tensor::set(std::vector<double> &data)
{
    this->data = data;
}

int convnet::Tensor::size()
{
    return data.size();
}

void convnet::Tensor::relu()
{
    for(size_t i=0; i<data.size(); i++)
    {
        if(data[i] < 0)
        {
            data[i] = 0;
        }
    }
}

void convnet::Tensor::sigmoid()
{
    for(size_t i=0; i<data.size(); i++)
    {
        data[i] = 1.0 / (1.0 + exp(-data[i]));
    }
}

void convnet::Tensor::argmax(int &s)
{
    if(dim.size() == 1 && dim[0] > 0)
    {
        int u = 0;

        for(int i=1; i<dim[0]; i++)
        {
            if(data[i] >= data[u])
            {
                u = i;
            }
        }

        s = u;
    }
}
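
A minimal usage sketch of the class (not part of the original sources, just the call pattern):

#include <vector>
#include "tensor.h"

int main()
{
    convnet::Tensor t(3);
    std::vector<double> v = {-1.0, 2.5, 0.5};

    t.set(v);       // copies the data in; sizes must already match
    t.relu();       // clamps negatives: {0.0, 2.5, 0.5}

    int best = 0;
    t.argmax(best); // index of the largest element, here 1

    return 0;
}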

3.conv2d.h

#ifndef _CONVNET_CONV2D_H_
#define _CONVNET_CONV2D_H_

#include "tensor.h"

namespace convnet {

    class Conv2d {
    public:
        Conv2d(int iw, int ih, int ic, int ow, int oh, int oc, int kw, int kh);
        Conv2d(int ic, int oc, int kw);
        ~Conv2d();

        void set(Tensor *i, Tensor *o);
        void setargs(std::vector<double> &k);
        void setargs(std::vector<double> &k, std::vector<double> &b);
        void forward();
        void print(int type);

    private:
        int iw;
        int ih;
        int ic;
        int ow;
        int oh;
        int oc;
        int kw;
        int kh;
        Tensor kernel;
        Tensor bias;
        Tensor *input;
        Tensor *output;
    };

}

#endif

4.conv2d.cpp

#include <stdio.h>
#include <iostream>
#include <cassert>
#include "conv2d.h"

using namespace std;

convnet::Conv2d::Conv2d(int iw, int ih, int ic, int ow, int oh, int oc, int kw, int kh)
{
    this->iw = iw;
    this->ih = ih;
    this->ic = ic;
    this->ow = ow;
    this->oh = oh;
    this->oc = oc;
    this->kw = kw;
    this->kh = kh;
}

convnet::Conv2d::Conv2d(int ic, int oc, int kw)
{
    this->iw = 0;
    this->ih = 0;
    this->ic = ic;
    this->ow = 0;
    this->oh = 0;
    this->oc = oc;
    this->kw = kw;
    this->kh = kw;
}

convnet::Conv2d::~Conv2d()
{

}

void convnet::Conv2d::set(Tensor *i, Tensor *o)
{
    input = i;
    output = o;

    if(input->size() != iw * ih * ic)
    {
        if(iw * ih != 0)
        {
            input->resize(ic, ih, iw);
        }
        else
        {
            if(input->dim.size() == 3)
            {
                /* Tensor dims are stored as {channels, height, width} */
                ih = input->dim[1];
                iw = input->dim[2];
            }
        }
    }

    if(output->size() != ow * oh * oc || ow * oh == 0)
    {
        if(ow * oh == 0)
        {
            /* valid convolution: no padding, stride 1 */
            ow = iw - kw + 1;
            oh = ih - kh + 1;
        }

        output->resize(oc, oh, ow);
    }
}

void convnet::Conv2d::setargs(std::vector<double> &k)
{
    if(kw * kh == k.size())
    {
        kernel.resize(kw * kh);
        kernel.set(k);
    }
}

void convnet::Conv2d::setargs(std::vector<double> &k, std::vector<double> &b)
{
    assert((int)k.size() == kw * kh * ic * oc);
    assert((int)b.size() == oc);

    bias.resize((int)b.size());
    bias.set(b);

    kernel.resize((int)k.size());
    kernel.set(k);
}

void convnet::Conv2d::forward()
{
    int i, j, k, l, p, q;

    for(i=0; i<oc; i++)
    {
        for(j=0; j<oh; j++)
        {
            for(k=0; k<ow; k++)
            {
                output->data[i*oh*ow+j*ow+k] = 0.0;

                for(l=0; l<ic; l++)
                {
                    for(p=0; p<kh; p++)
                    {
                        for(q=0; q<kw; q++)
                        {
                            /* kernel layout: [oc][ic][kh][kw]; input: [ic][ih][iw] */
                            output->data[i*oh*ow+j*ow+k] += kernel.data[(i*ic+l)*kh*kw+p*kw+q] * input->data[l*ih*iw+(j+p)*iw+(k+q)];
                        }
                    }
                }

                output->data[i*oh*ow+j*ow+k] += bias.data[i];
            }
        }
    }
}

void convnet::Conv2d::print(int type)
{
    if(type == 0)
    {
        int m = iw * ih * ic;

        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        int n = ow * oh * oc;

        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", output->data[i]);
        }

        printf("\n");
    }
}
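
With the shapes from lenet.xml, set() and forward() implement a valid convolution (no padding, stride 1). For conv1, for example:

// conv1: Conv2d(ic=1, oc=6, k=5) on a 1x28x28 input
// ow = iw - kw + 1 = 28 - 5 + 1 = 24
// oh = ih - kh + 1 = 28 - 5 + 1 = 24
// output: 6x24x24; the kernel holds oc*ic*kh*kw = 6*1*5*5 = 150 doubles
// and the bias holds oc = 6 doubles.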

5.maxpool2d.h

#ifndef _CONVNET_MAXPOOL2D_H_
#define _CONVNET_MAXPOOL2D_H_

#include "tensor.h"

namespace convnet {

    class MaxPool2d {
    public:
        MaxPool2d(int iw, int ih, int w, int h, int stride);
        MaxPool2d(int w, int h, int stride);
        MaxPool2d(int w, int stride);
        ~MaxPool2d();

        void set(Tensor *i, Tensor *o);
        void forward();
        void print(int type);

    private:
        int ic;
        int iw;
        int ih;
        int ow;
        int oh;
        int oc;
        int kw;
        int kh;
        int stride;
        Tensor *input;
        Tensor *output;
    };

}

#endif

6.maxpool2d.cpp

#include <stdio.h>
#include <iostream>
#include "maxpool2d.h"

using namespace std;

convnet::MaxPool2d::MaxPool2d(int iw, int ih, int w, int h, int stride)
{
    this->iw = iw;
    this->ih = ih;
    this->ic = 0;
    this->ow = 0;
    this->oh = 0;
    this->oc = 0;
    this->kw = w;
    this->kh = h;
    this->stride = stride;

    ow = iw / stride;
    oh = ih / stride;
}

convnet::MaxPool2d::MaxPool2d(int w, int h, int stride)
{
    this->iw = 0;
    this->ih = 0;
    this->ic = 0;
    this->ow = 0;
    this->oh = 0;
    this->oc = 0;
    this->kw = w;
    this->kh = h;
    this->stride = stride;
}

convnet::MaxPool2d::MaxPool2d(int w, int stride)
{
    this->iw = 0;
    this->ih = 0;
    this->ic = 0;
    this->ow = 0;
    this->oh = 0;
    this->oc = 0;
    this->kw = w;
    this->kh = w;
    this->stride = stride;
}

convnet::MaxPool2d::~MaxPool2d()
{

}

void convnet::MaxPool2d::set(Tensor *i, Tensor *o)
{
    if(iw != 0 && ih != 0)
    {
        if(iw * ih != i->size())
        {
            i->resize(iw, ih);
        }
    }
    else
    {
        if(i->dim.size() == 2)
        {
            iw = i->dim[0];
            ih = i->dim[1];
            ic = 1;
            oc = 1;
        }
        else if(i->dim.size() == 3)
        {
            /* Tensor dims are stored as {channels, height, width} */
            ic = i->dim[0];
            ih = i->dim[1];
            iw = i->dim[2];
            oc = ic;
        }

        ow = iw / stride;
        oh = ih / stride;
    }

    if(ow * oh * oc != 0 && ow * oh * oc != o->size())
    {
        o->resize(oc, oh, ow);
    }

    input = i;
    output = o;
}

void convnet::MaxPool2d::forward()
{
    int i, j, k;
    int p, q;

    for(i=0; i<oc; i++)
    {
        for(j=0; j<oh; j++)
        {
            for(k=0; k<ow; k++)
            {
                output->data[i*oh*ow+j*ow+k] = input->data[i*ih*iw+j*stride*iw+k*stride];

                /* scan a stride x stride window; with kernel 2, stride 2
                   (as in lenet.xml) this matches the kw x kh pooling kernel */
                for(p=0; p<stride; p++)
                {
                    for(q=0; q<stride; q++)
                    {
                        if(input->data[i*ih*iw+(j*stride+p)*iw+(k*stride+q)] > output->data[i*oh*ow+j*ow+k])
                        {
                            output->data[i*oh*ow+j*ow+k] = input->data[i*ih*iw+(j*stride+p)*iw+(k*stride+q)];
                        }
                    }
                }
            }
        }
    }
}

void convnet::MaxPool2d::print(int type)
{
    if(type == 0)
    {
        int m = ic * ih * iw;

        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        int n = oc * oh * ow;

        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", output->data[i]);
        }

        printf("\n");
    }
}
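
Note that forward() scans a stride x stride window rather than kw x kh, which only works because both pooling nodes in lenet.xml use kernel 2, stride 2. With those settings:

// pool1: 6x24x24  -> 6x12x12   (ow = iw / stride = 24 / 2)
// pool2: 16x8x8   -> 16x4x4    (8 / 2)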

7.reshape.h

#ifndef _CONVNET_RESHAPE_H_
#define _CONVNET_RESHAPE_H_

#include "tensor.h"

namespace convnet {

    class Reshape {
    public:
        Reshape(int a, int b);
        ~Reshape();

        void set(Tensor *input, Tensor *output);
        void forward();
        void print(int type);

    private:
        std::vector<int> dim;
        Tensor *input;
        Tensor *output;
    };

}

#endif

8.reshape.cpp

#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include "reshape.h"

convnet::Reshape::Reshape(int a, int b)
{
    dim.resize(2);

    dim[0] = a;
    dim[1] = b;
}

convnet::Reshape::~Reshape()
{

}

void convnet::Reshape::set(Tensor *input, Tensor *output)
{
    if(dim.size() > 0)
    {
        if(dim[0] == -1 && dim[1] > 0)
        {
            if(input->dim.size() == 3)
            {
                assert(input->dim[0] * input->dim[1] * input->dim[2] == dim[1]);
            }
            else if(input->dim.size() == 2)
            {
                assert(input->dim[0] * input->dim[1] == dim[1]);
            }

            output->resize(dim[1]);
        }

        this->input = input;
        this->output = output;
    }
}

void convnet::Reshape::forward()
{
    output->data = input->data;
}

void convnet::Reshape::print(int type)
{
    if(type == 0)
    {
        int m = abs(dim[0] * dim[1]);

        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        int n = abs(dim[0] * dim[1]);

        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", output->data[i]);
        }

        printf("\n");
    }
}
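
In this network the only Reshape is the view node from lenet.xml, which flattens pool2's output:

// view (-1, 256): pool2 output 16x4x4 -> a 256-element vector;
// set() asserts 16*4*4 == 256 and forward() is a plain data copy.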

9.linear.h

#ifndef _CONVNET_LINEAR_H_
#define _CONVNET_LINEAR_H_

#include "tensor.h"

namespace convnet {

    class Linear {
    public:
        Linear(int m, int n);

        void set(Tensor *i, Tensor *o);
        void forward();
        void setargs(std::vector<double> &w, std::vector<double> &b);
        void print(int type);

    private:
        int m;
        int n;
        Tensor weight;
        Tensor bias;
        Tensor *input;
        Tensor *output;
    };

}

#endif

10.linear.cpp

#include <stdio.h>
#include <cassert>
#include "linear.h"

convnet::Linear::Linear(int m, int n)
{
    this->m = m;
    this->n = n;

    weight.resize(m, n);
    bias.resize(n);
}

void convnet::Linear::set(Tensor *i, Tensor *o)
{
    //printf("m = %d, %d\n", m, i->size());

    if(m != i->size())
    {
        i->resize(m);
    }

    if(n != o->size())
    {
        o->resize(n);
    }

    input = i;
    output = o;
}

void convnet::Linear::setargs(std::vector<double> &w, std::vector<double> &b)
{
    assert((int)w.size() == m * n && (int)b.size() == n);

    weight.data = w;
    bias.data = b;
}

void convnet::Linear::forward()
{
    for(int out=0; out<n; out++)
    {
        output->data[out] = 0.0;

        for(int in=0; in<m; in++)
        {
            output->data[out] += input->data[in] * weight.data[m*out + in];
        }

        output->data[out] += bias.data[out];
    }
}

void convnet::Linear::print(int type)
{
    if(type == 0)
    {
        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", output->data[i]);
        }

        printf("\n");
    }
}
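
forward() computes a dense y = Wx + b with the weight stored row-major as n rows of m entries:

// output[out] = sum over in of input[in] * weight.data[m*out + in],
// then + bias.data[out]
// fc1 from lenet.xml: m = 256, n = 120 -> 30720 weights, 120 biases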

11.main.cpp

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cassert>
#include "conv2d.h"
#include "maxpool2d.h"
#include "linear.h"
#include "reshape.h"
#include "pugixml.hpp"

using namespace std;
using namespace convnet;

void get_numbers(const std::string &line, std::vector<int> &s)
{
    char num[32];
    int len = line.length();
    int i = 0;
    int j = 0;
    int k = 0;
    int flag = 0;

    while(i < len)
    {
        flag = 0;

        if(line[i] == '+' || line[i] == '-' || (line[i] >= '0' && line[i] <= '9'))
        {
            flag = 1;

            num[j++] = line[i];
        }
        else if(isspace(line[i]))
        {

        }
        else
        {
            break;
        }

        i++;

        if(flag == 1)
        {
            while(line[i] == '+' || line[i] == '-' || (line[i] >= '0' && line[i] <= '9'))
            {
                num[j++] = line[i];

                i++;
            }

            num[j] = 0x00;

            if(i < len && line[i] == ',')
            {
                i++;
            }

            k = atoi(num);

            j = 0;

            s.push_back(k);
        }
    }
}

void get_numbers(const std::string &line, std::vector<double> &s)
{
    char num[32];
    int len = line.length();
    int i = 0;
    int j = 0;
    double k = 0;
    int flag = 0;

    while(i < len)
    {
        flag = 0;

        if(line[i] == '+' || line[i] == '-' || line[i] == 'e' || line[i] == '.' || (line[i] >= '0' && line[i] <= '9'))
        {
            flag = 1;

            num[j++] = line[i];
        }
        else if(isspace(line[i]))
        {

        }
        else
        {
            break;
        }

        i++;

        if(flag == 1)
        {
            while(line[i] == '+' || line[i] == '-' || line[i] == 'e' || line[i] == '.' || (line[i] >= '0' && line[i] <= '9'))
            {
                num[j++] = line[i];

                i++;
            }

            num[j] = 0x00;

            if(i < len && line[i] == ',')
            {
                i++;
            }

            char *endptr = NULL;

            k = strtod(num, &endptr);

            j = 0;

            s.push_back(k);
        }
    }
}
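
// Example: get_numbers("-0.4242, 2.8088, 1e2", v) appends
// {-0.4242, 2.8088, 100.0} to v; parsing stops at the first character
// that is not part of a number, whitespace, or a separating comma.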

int main()
{
    Tensor *input = NULL;
    std::vector<Tensor *> outputs;
    std::vector<Conv2d *> conv2ds;
    std::vector<MaxPool2d *> maxpool2ds;
    std::vector<Reshape *> reshapes;
    std::vector<Linear *> linears;
    std::vector<int> actfuncs;
    std::vector<std::vector<int>> args;
    std::vector<int> args_brh;
    std::vector<std::string> types;
    std::vector<std::string> names;
    std::vector<int> intypes;
    std::vector<double> nums[2];
    int k = 0;
    int u = 0;
    int v = 0;
    int w = 0;
    int l = 0;

    pugi::xml_document doc;
    pugi::xml_parse_result result = doc.load_file("lenet.xml");

    if(!result)
    {
        return -1;
    }

    pugi::xml_node xnodes = doc.child("lenet");

    for(pugi::xml_node xnode = xnodes.first_child(); xnode; xnode = xnode.next_sibling())
    {
        std::string type = xnode.child("type").text().as_string();
        std::string name = xnode.child("name").text().as_string();

        if(type != "input" && type != "Conv2d" && type != "MaxPool2d"
            && type != "reshape" && type != "Linear")
        {
            continue;
        }

        /* input and Conv2d nodes carry three numbers in <value>, the others two */
        int expected = (type == "input" || type == "Conv2d") ? 3 : 2;

        std::string value = xnode.child("value").text().as_string();

        get_numbers(value, args_brh);

        if((int)args_brh.size() == expected)
        {
            types.push_back(type);
            names.push_back(name);
            args.push_back(args_brh);

            /* an optional <actfunc>relu</actfunc> adds a ReLU after the layer */
            std::string func = xnode.child("actfunc").text().as_string();

            actfuncs.push_back(func == "relu" ? 1 : -1);
        }

        args_brh.clear();
    }

    for(k=0; k<(int)types.size(); k++)
    {
        string type = types[k];

        if(type == "input")
        {
            input = new Tensor(args[k][0], args[k][1], args[k][2]);

            intypes.push_back(1);
        }
        else if(type == "Conv2d")
        {
            conv2ds.push_back(new Conv2d(args[k][0], args[k][1], args[k][2]));
            outputs.push_back(new Tensor());

            u = conv2ds.size() - 1;
            v = outputs.size() - 1;

            if(v == 0)
            {
                conv2ds[u]->set(input, outputs[v]);
            }
            else
            {
                conv2ds[u]->set(outputs[v-1], outputs[v]);
            }

            intypes.push_back(2);
        }
        else if(type == "MaxPool2d")
        {
            maxpool2ds.push_back(new MaxPool2d(args[k][0], args[k][1]));
            outputs.push_back(new Tensor());

            u = maxpool2ds.size() - 1;
            v = outputs.size() - 1;

            maxpool2ds[u]->set(outputs[v-1], outputs[v]);

            intypes.push_back(3);
        }
        else if(type == "reshape")
        {
            reshapes.push_back(new Reshape(args[k][0], args[k][1]));
            outputs.push_back(new Tensor());

            u = reshapes.size() - 1;
            v = outputs.size() - 1;

            reshapes[u]->set(outputs[v-1], outputs[v]);

            intypes.push_back(4);
        }
        else if(type == "Linear")
        {
            linears.push_back(new Linear(args[k][0], args[k][1]));
            outputs.push_back(new Tensor());

            u = linears.size() - 1;
            v = outputs.size() - 1;

            linears[u]->set(outputs[v-1], outputs[v]);

            intypes.push_back(5);
        }
    }

    FILE *fp = fopen("lenet.mod", "rb");
    char line[128];
    int nodes = 0;

    if(fp == NULL)
    {
        return -1;
    }

    memset(line, 0x00, sizeof(line));

    fread(line, 1, 6, fp);
    fread(&nodes, 1, 4, fp);

    if(strcmp(line, "lenet") != 0)
    {
        return -1;
    }

    u = 0;
    v = 0;
    w = 0;

    /* each Conv2d/Linear layer contributes two records to lenet.mod: a
       NUL-terminated "<name>.weight" label, an int32 count and that many
       doubles, then the same again for "<name>.bias" */
    auto read_record = [&](size_t label_len, std::vector<double> &out)
    {
        memset(line, 0x00, sizeof(line));

        fread(line, 1, label_len, fp);
        fread(&nodes, 1, 4, fp);

        out.resize(nodes);

        for(int i=0; i<nodes; i++)
        {
            fread(&out[i], 1, sizeof(double), fp);
        }
    };

    for(k=0; k<(int)types.size(); k++)
    {
        if(intypes[k] == 2)
        {
            read_record(names[k].length()+8, nums[0]); /* "<name>.weight" + NUL */
            read_record(names[k].length()+6, nums[1]); /* "<name>.bias" + NUL */

            conv2ds[u++]->setargs(nums[0], nums[1]);
        }
        else if(intypes[k] == 5)
        {
            read_record(names[k].length()+8, nums[0]);
            read_record(names[k].length()+6, nums[1]);

            linears[v++]->setargs(nums[0], nums[1]);
        }
    }

    fclose(fp);

    nums[0].clear();
    nums[1].clear();

    fp = fopen("input.txt", "r");

    if(fp == NULL)
    {
        return -1;
    }

    while(fgets(line, sizeof(line), fp) != NULL)
    {
        get_numbers(line, nums[0]);
    }

    fclose(fp);

    //cout << "nums = " << nums[0].size() << endl;

    assert(nums[0].size() == input->size());

    input->set(nums[0]);

    nums[0].clear();

    u = 0;
    v = 0;
    w = 0;
    l = 0;

    /* run the layers in order: layer k writes into outputs[k-1] */
    for(k=1; k<(int)types.size(); k++)
    {
        switch(intypes[k])
        {
            case 2:
                conv2ds[u++]->forward();
                break;

            case 3:
                maxpool2ds[v++]->forward();
                break;

            case 4:
                reshapes[w++]->forward();
                break;

            case 5:
                linears[l++]->forward();
                break;

            default:
                break;
        }

        if(actfuncs[k] == 1)
        {
            outputs[k-1]->relu();
        }
    }

    /* the final layer's output is the last element of outputs */
    k = (int)types.size() - 2;

    u = 0;

    outputs[k]->argmax(u);

    printf("predict: %d\n", u);

    delete input;

    for(k=0; k<(int)conv2ds.size(); k++) delete conv2ds[k];
    for(k=0; k<(int)outputs.size(); k++) delete outputs[k];
    for(k=0; k<(int)maxpool2ds.size(); k++) delete maxpool2ds[k];
    for(k=0; k<(int)reshapes.size(); k++) delete reshapes[k];
    for(k=0; k<(int)linears.size(); k++) delete linears[k];

    return 0;
}
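
The reads above imply a simple layout for lenet.mod: a 6-byte "lenet" magic (including the terminating NUL), a 4-byte node count, then for every Conv2d/Linear node a NUL-terminated "<name>.weight" label, a 4-byte count and that many 8-byte doubles, followed by the same pattern for "<name>.bias" (the names[k].length()+8 and +6 reads are those label lengths including the NUL). A small standalone inspector, sketched against that inferred layout, helps when checking an exported weight file:

#include <stdio.h>

// Walk lenet.mod assuming the record layout inferred from main.cpp above.
int main()
{
    FILE *fp = fopen("lenet.mod", "rb");
    char name[64];
    int count = 0;
    int c = 0;
    int i = 0;

    if(fp == NULL)
    {
        return -1;
    }

    fread(name, 1, 6, fp);   /* "lenet" + NUL */
    fread(&count, 1, 4, fp); /* node count */

    printf("magic=%s nodes=%d\n", name, count);

    while((c = fgetc(fp)) != EOF)
    {
        i = 0;

        while(c != 0 && c != EOF && i < 63)
        {
            name[i++] = (char)c;
            c = fgetc(fp);
        }

        name[i] = 0x00;

        if(fread(&count, 1, 4, fp) != 4)
        {
            break;
        }

        printf("%s: %d doubles\n", name, count);

        fseek(fp, count * (long)sizeof(double), SEEK_CUR);
    }

    fclose(fp);

    return 0;
}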

12.Makefile

CXX=g++
STD=-std=c++11
DEBUG=-g
LDFLAGS=
CXXFLAGS=
OBJS=linear.o tensor.o conv2d.o maxpool2d.o pugixml.o reshape.o

lenet: main.cpp $(OBJS)
    $(CXX) $(DEBUG) -o lenet main.cpp $(OBJS) $(STD) $(LDFLAGS)

linear.o: linear.cpp linear.h
    $(CXX) $(DEBUG) -c linear.cpp $(STD) $(CXXFLAGS)

tensor.o: tensor.cpp tensor.h
    $(CXX) $(DEBUG) -c tensor.cpp $(STD) $(CXXFLAGS)

conv2d.o: conv2d.cpp conv2d.h
    $(CXX) $(DEBUG) -c conv2d.cpp $(STD) $(CXXFLAGS)

maxpool2d.o: maxpool2d.cpp maxpool2d.h
    $(CXX) $(DEBUG) -c maxpool2d.cpp $(STD) $(CXXFLAGS)

reshape.o: reshape.cpp reshape.h
    $(CXX) $(DEBUG) -c reshape.cpp $(STD) $(CXXFLAGS)

pugixml.o: pugixml.cpp pugixml.hpp
    $(CXX) $(DEBUG) -c pugixml.cpp $(STD) $(CXXFLAGS)

clean:
    rm -rf lenet
    rm -rf $(OBJS)
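
The build expects pugixml to be vendored as plain sources in the project directory. If they are missing, the two files (plus the config header that pugixml.hpp includes) can be copied from the upstream repository, for example:

$ wget https://raw.githubusercontent.com/zeux/pugixml/master/src/pugixml.hpp
$ wget https://raw.githubusercontent.com/zeux/pugixml/master/src/pugixml.cpp
$ wget https://raw.githubusercontent.com/zeux/pugixml/master/src/pugiconfig.hpp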

13.lenet.xml

<?xml version="1.0" encoding="UTF-8"?>
<lenet>
    <node>
        <type>input</type>
        <name>input</name>
        <value>1, 28, 28</value>
    </node>
    <node>
        <type>Conv2d</type>
        <name>conv1</name>
        <value>1, 6, 5</value>
        <actfunc>relu</actfunc>
    </node>
    <node>
        <type>MaxPool2d</type>
        <name>pool1</name>
        <value>2, 2</value>
    </node>
    <node>
        <type>Conv2d</type>
        <name>conv2</name>
        <value>6, 16, 5</value>
        <actfunc>relu</actfunc>
    </node>
    <node>
        <type>MaxPool2d</type>
        <name>pool2</name>
        <value>2, 2</value>
    </node>
    <node>
        <type>reshape</type>
        <name>view</name>
        <value>-1, 256</value>
    </node>
    <node>
        <type>Linear</type>
        <name>fc1</name>
        <value>256, 120</value>
        <actfunc>relu</actfunc>
    </node>
    <node>
        <type>Linear</type>
        <name>fc2</name>
        <value>120, 84</value>
        <actfunc>relu</actfunc>
    </node>
    <node>
        <type>Linear</type>
        <name>fc3</name>
        <value>84, 10</value>
    </node>
</lenet>
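
Tracing the tensor shapes through these nodes shows where the view node's 256 comes from:

input   1 x 28 x 28
conv1   6 x 24 x 24    (28 - 5 + 1 = 24)
pool1   6 x 12 x 12    (24 / 2)
conv2   16 x 8 x 8     (12 - 5 + 1 = 8)
pool2   16 x 4 x 4     (8 / 2)
view    256            (16 * 4 * 4)
fc1     120
fc2     84
fc3     10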

14.input.txt

-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.6450,
 1.9305,  1.5996,  1.4978,  0.3395,  0.0340, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  2.4015,
 2.8088,  2.8088,  2.8088,  2.8088,  2.6433,  2.0960,  2.0960,
 2.0960,  2.0960,  2.0960,  2.0960,  2.0960,  2.0960,  1.7396,
 0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.4286,
 1.0268,  0.4922,  1.0268,  1.6505,  2.4651,  2.8088,  2.4396,
 2.8088,  2.8088,  2.8088,  2.7578,  2.4906,  2.8088,  2.8088,
 1.3577, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.2078,  0.4159, -0.2460,
 0.4286,  0.4286,  0.4286,  0.3268, -0.1569,  2.5797,  2.8088,
 0.9250, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242,  0.6322,  2.7960,  2.2360,
-0.1951, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.1442,  2.5415,  2.8215,  0.6322,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242,  1.2177,  2.8088,  2.6051,  0.1358,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242,  0.3268,  2.7451,  2.8088,  0.3649, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242,  1.2686,  2.8088,  1.9560, -0.3606, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.3097,  2.1851,  2.7324,  0.3140, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242,  1.1795,  2.8088,  1.8923, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
 0.5304,  2.7706,  2.6306,  0.3013, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.1824,
 2.3887,  2.8088,  1.6887, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.3860,  2.1596,
 2.8088,  2.3633,  0.0213, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.0595,  2.8088,
 2.8088,  0.5559, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.0296,  2.4269,  2.8088,
 1.0395, -0.4115, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242,  1.2686,  2.8088,  2.8088,
 0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242,  0.3522,  2.6560,  2.8088,  2.8088,
 0.2377, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242,  1.1159,  2.8088,  2.8088,  2.3633,
 0.0849, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242,  1.1159,  2.8088,  2.2105, -0.1951,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242
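
These 784 values look like an MNIST digit put through the standard normalization (mean 0.1307, std 0.3081), under which a background pixel of 0 becomes

(0 - 0.1307) / 0.3081 = -0.4242

which matches the constant filling most of the grid.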

15.Building the source

$ make

16.Running it and the result

$ ./lenet
predict: 7