C++ 实现 RNN 推理模型

2024-01-20  本文已影响0人  一路向后

1.问题

用 RNN 实现：输入一个字母，预测出下一个字母（字母 a–e 循环，e 的下一个是 a）：
输入a, 预测出b
输入b, 预测出c
输入c, 预测出d
输入d, 预测出e
输入e, 预测出a

2.tensor.h

#ifndef _CONVNET_TENSOR_H_
#define _CONVNET_TENSOR_H_

#include <vector>

/// Scalar element type shared by every tensor buffer and layer parameter vector.
using Real = double;

namespace convnet {

    /// Minimal dense tensor: a shape vector (`dim`) plus a flat buffer of
    /// `Real` values (`data`). The layer classes listed as friends read and
    /// write both members directly.
    class Tensor {
    public:
        /// Creates a tensor with no shape and no data.
        Tensor();
        /// Creates a 1-D tensor holding `len` zero-valued elements.
        Tensor(int len);
        /// Creates a 2-D tensor with shape {d0, d1}.
        Tensor(int d0, int d1);
        /// Creates a 3-D tensor with shape {d0, d1, d2}.
        Tensor(int d0, int d1, int d2);
        ~Tensor();

        /// Change the shape to 1-D / 2-D / 3-D (see tensor.cpp for how the
        /// existing buffer is treated by each overload).
        void resize(int len);
        void resize(int d0, int d1);
        void resize(int d0, int d1, int d2);

        /// In-place activations applied to the stored values.
        void relu();
        void tanh();
        void sigmoid();
        /// Stores the index of the largest element in `index`
        /// (preconditions are described in tensor.cpp).
        void argmax(int &index);

        /// Replaces the flat buffer with a copy of `values`; `dim` is not touched.
        void set(std::vector<Real> &values);
        /// Number of elements currently held in the flat buffer.
        int size();

    private:
        friend class Linear;
        friend class Conv2d;
        friend class MaxPool2d;
        friend class Reshape;
        friend class Rnn;

        std::vector<int> dim;   // shape, one entry per dimension
        std::vector<Real> data; // flat element storage
    };

}

#endif

3.tensor.cpp

#include <math.h>
#include <iostream>
#include "tensor.h"

using namespace std;

// Default constructor: both the shape (dim) and the buffer (data) start empty.
convnet::Tensor::Tensor() = default;

// 1-D constructor: the shape becomes {a} and the buffer holds `a`
// zero-valued elements; a non-positive `a` leaves the buffer empty.
convnet::Tensor::Tensor(int a)
    : dim(1, a),
      data(a > 0 ? static_cast<std::size_t>(a) : 0)
{
}

convnet::Tensor::Tensor(int a, int b)
{
    dim.resize(2);

    dim[0] = a;
    dim[1] = b;

    if(a*b > 0)
    {
        data.resize(a*b);
    }
}

convnet::Tensor::Tensor(int a, int b, int c)
{
    dim.resize(3);

    dim[0] = a;
    dim[1] = b;
    dim[2] = c;

    if(a*b*c > 0)
    {
        data.resize(a*b*c);
    }
}

// Reshape to 1-D {a}. Existing elements are preserved by std::vector::resize
// (truncated, or zero-padded when growing). A non-positive length empties
// the buffer so that `data` always agrees with `dim`, instead of leaving
// stale elements behind a zero-length shape.
void convnet::Tensor::resize(int a)
{
    dim = {a};
    data.resize(a > 0 ? static_cast<std::size_t>(a) : 0);
}

// Reshape to 2-D {a, b}. Existing elements are preserved by
// std::vector::resize (truncated or zero-padded). The buffer is sized a*b
// only when both dimensions are positive; otherwise it is emptied so that
// `data` never disagrees with `dim`. This also means two negative
// dimensions no longer produce a positive element count.
void convnet::Tensor::resize(int a, int b)
{
    dim = {a, b};

    const std::size_t count = (a > 0 && b > 0)
        ? static_cast<std::size_t>(a) * static_cast<std::size_t>(b)
        : 0;

    data.resize(count);
}

// Reshape to 3-D {a, b, c}. Unlike the 1-D and 2-D overloads, this one
// discards the previous contents: the buffer is cleared and re-filled with
// a*b*c zeros. Storage is allocated only when every dimension is positive,
// so two negative dimensions cannot yield a positive element count.
void convnet::Tensor::resize(int a, int b, int c)
{
    dim = {a, b, c};
    data.clear();

    if(a > 0 && b > 0 && c > 0)
    {
        // size_t arithmetic avoids int overflow for large shapes.
        data.assign(static_cast<std::size_t>(a) *
                    static_cast<std::size_t>(b) *
                    static_cast<std::size_t>(c), 0.0);
    }
}

// Destructor: nothing to release by hand; the dim and data vectors free
// their own storage.
convnet::Tensor::~Tensor() = default;

void convnet::Tensor::set(std::vector<Real> &data)
{
    this->data = data;
}

// Number of elements stored in the flat buffer (not the product of `dim`).
int convnet::Tensor::size()
{
    const std::size_t count = data.size();
    return static_cast<int>(count);
}

void convnet::Tensor::relu()
{
    for(int i=0; i<data.size(); i++)
    {
        if(data[i] < 0)
        {
            data[i] = 0;
        }
    }
}

void convnet::Tensor::sigmoid()
{
    for(int i=0; i<data.size(); i++)
    {
        data[i] = 1 / (1+expf(-data[i]));
    }
}

// Store in `s` the index of the largest element of the flat buffer.
// - Ties resolve to the LAST maximal index (comparison uses >=), as before.
//   Note that NumPy/PyTorch return the first occurrence instead.
// - Tensors of any rank are handled by scanning the flattened buffer.
//   Previously `s` was left untouched for non-1-D tensors, so callers could
//   read an uninitialised value.
// - An empty tensor yields s = -1.
void convnet::Tensor::argmax(int &s)
{
    if(data.empty())
    {
        s = -1;
        return;
    }

    std::size_t best = 0;

    for(std::size_t k = 1; k < data.size(); ++k)
    {
        if(data[k] >= data[best])
        {
            best = k;
        }
    }

    s = static_cast<int>(best);
}

void convnet::Tensor::tanh()
{
    if(dim.size() == 1 && dim[0] > 0)
    {
        int i = 1;

        for(i=0; i<dim[0]; i++)
        {
            data[i] = std::tanh(data[i]);
        }
    }
}

4.rnn.h

#ifndef _CONVNET_RNN_H_
#define _CONVNET_RNN_H_

#include <string>
#include "tensor.h"

namespace convnet {

    /// Single recurrent cell. Each forward() computes
    ///   h_new = tanh(sw * x + sb + ow * h_prev + ob)
    /// and keeps h_new as the state for the next call.
    class Rnn {
    public:
        /// m: length of the input vector, n: size of the hidden state.
        Rnn(int m, int n);
        ~Rnn();

        /// Bind the (non-owned) input tensor and the tensor that receives the
        /// hidden state. Must be called before forward()/print().
        void set(Tensor *i, Tensor *h);
        /// w1: input weights (m*n values), w2: recurrent weights (n*n),
        /// b1/b2: the two n-element bias vectors.
        void setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2);
        /// Run one time step and write the new state into the bound hidden tensor.
        void forward();
        /// type 0 prints the bound input, type 1 the bound hidden tensor.
        void print(int type);

    private:
        int m;
        int n;
        Tensor sw;  // input-to-hidden weights
        Tensor sb;  // bias added after the input term
        Tensor ow;  // hidden-to-hidden (recurrent) weights
        Tensor ob;  // bias added after the recurrent term
        Tensor ht;  // hidden state carried between forward() calls
        Tensor st;  // scratch buffer for the new state
        // Non-owning; null until set() is called, so misuse is detectable
        // instead of dereferencing an indeterminate pointer.
        Tensor *input = nullptr;
        Tensor *hidden = nullptr;
    };

}

#endif

5.rnn.cpp

#include <cassert>
#include <cstdio>
#include "rnn.h"

// Build a cell with `m` inputs and `n` hidden units. All parameters and the
// hidden state start at zero until setargs() loads real weights.
convnet::Rnn::Rnn(int m, int n)
    : m(m), n(n), input(nullptr), hidden(nullptr)
{
    // forward() reads sw.data[m*i + j] with i < n and j < m, i.e. one row of
    // m weights per hidden unit, so the logical shape is (n, m), not (m, n).
    sw.resize(n, m);
    sb.resize(n);
    ow.resize(n, n);
    ob.resize(n);
    st.resize(n);
    ht.resize(n);

    // Start from an all-zero hidden state.
    for(Real &value : ht.data)
    {
        value = 0.0;
    }
}

// Destructor: the cell owns only Tensor members, which clean up after
// themselves; the bound input/hidden tensors are not owned.
convnet::Rnn::~Rnn() = default;

// Bind the input tensor `i` and the tensor `h` that receives the hidden state.
// Neither is owned. forward() reads m values from `i` and writes n values to
// `h`, so both are resized when their length does not match. This is the
// same policy Linear::set applies; previously only `h` was checked, and a
// short input could be read out of bounds.
void convnet::Rnn::set(Tensor *i, Tensor *h)
{
    assert(i != nullptr && h != nullptr);

    if(m != i->size())
    {
        i->resize(m);
    }

    if(n != 0 && h->size() != n)
    {
        h->resize(n);
    }

    this->input = i;
    this->hidden = h;
}

void convnet::Rnn::forward()
{
    int i, j;

    for(i=0; i<n; i++)
    {
        st.data[i] = 0.0;

        for(j=0; j<m; j++)
        {
            st.data[i] += input->data[j] * sw.data[m*i + j];
        }

        st.data[i] += sb.data[i];

        for(j=0; j<n; j++)
        {
            st.data[i] += ht.data[j] * ow.data[n*i + j];
        }

        st.data[i] += ob.data[i];
    }

    st.tanh();

    for(i=0; i<n; i++)
    {
        ht.data[i] = st.data[i];
        hidden->data[i] = st.data[i];
    }
}

// Load the cell parameters. Expected sizes (enforced only by assert, so
// NDEBUG builds trust the caller):
//   w1: m*n input weights,     b1: n biases for the input term
//   w2: n*n recurrent weights, b2: n biases for the recurrent term
void convnet::Rnn::setargs(std::vector<Real> &w1, std::vector<Real> &w2, std::vector<Real> &b1, std::vector<Real> &b2)
{
    assert(w1.size() == m * n && b1.size() == n);
    assert(w2.size() == n * n && b2.size() == n);

    // Input path first, then the recurrent path.
    sw.set(w1);
    sb.set(b1);

    ow.set(w2);
    ob.set(b2);
}

void convnet::Rnn::print(int type)
{
    if(type == 0)
    {
        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", hidden->data[i]);
        }

        printf("\n");
    }
}

6.linear.h

#ifndef _CONVNET_LINEAR_H_
#define _CONVNET_LINEAR_H_

#include "tensor.h"

namespace convnet {

    /// Fully connected layer:
    ///   output[o] = sum_i input[i] * weight[m*o + i] + bias[o]
    class Linear {
    public:
        /// m: number of inputs, n: number of outputs.
        Linear(int m, int n);

        /// Bind the (non-owned) input and output tensors, resizing them to
        /// m and n elements if needed. Must be called before forward()/print().
        void set(Tensor *i, Tensor *o);
        void forward();
        /// w: m*n weights (one row of m per output), b: n biases.
        void setargs(std::vector<Real> &w, std::vector<Real> &b);
        /// type 0 prints the bound input, type 1 the bound output.
        void print(int type);

    private:
        int m;
        int n;
        Tensor weight;
        Tensor bias;
        // Non-owning; null until set() is called, so misuse is detectable
        // instead of dereferencing an indeterminate pointer.
        Tensor *input = nullptr;
        Tensor *output = nullptr;
    };

}

#endif

7.linear.cpp

#include <stdio.h>
#include <cassert>
#include "linear.h"

// Build a layer mapping `m` inputs to `n` outputs; parameters start at zero
// until setargs() loads real values.
convnet::Linear::Linear(int m, int n)
    : m(m), n(n), input(nullptr), output(nullptr)
{
    // forward() reads weight.data[m*out + in] with out < n and in < m, i.e.
    // one row of m weights per output unit, so the logical shape is (n, m).
    weight.resize(n, m);
    bias.resize(n);
}

// Bind the layer to input `i` and output `o` (neither is owned). Each tensor
// is resized when its length differs from what forward() will touch:
// m elements read from `i`, n elements written to `o`.
void convnet::Linear::set(Tensor *i, Tensor *o)
{
    const bool input_mismatch = (i->size() != m);
    if(input_mismatch)
    {
        i->resize(m);
    }

    const bool output_mismatch = (o->size() != n);
    if(output_mismatch)
    {
        o->resize(n);
    }

    this->input = i;
    this->output = o;
}

// Copy the layer parameters in. `w` holds m*n weights (read by forward() as
// one row of m per output unit) and `b` holds n biases. Sizes are checked by
// assert only, so NDEBUG builds trust the caller.
void convnet::Linear::setargs(std::vector<Real> &w, std::vector<Real> &b)
{
    assert(w.size() == m * n && b.size() == n);

    bias.data = b;
    weight.data = w;
}

void convnet::Linear::forward()
{
    for(int out=0; out<n; out++)
    {
        output->data[out] = 0.0;

        for(int in=0; in<m; in++)
        {
            output->data[out] += input->data[in] * weight.data[m*out + in];
        }

        output->data[out] += bias.data[out];
    }
}

void convnet::Linear::print(int type)
{
    if(type == 0)
    {
        for(int i=0; i<m; i++)
        {
            printf("%.6lf ", input->data[i]);
        }

        printf("\n");
    }
    else if(type == 1)
    {
        for(int i=0; i<n; i++)
        {
            printf("%.6lf ", output->data[i]);
        }

        printf("\n");
    }
}

8.main.cpp

#include <iostream>
#include <vector>
#include "tensor.h"
#include "rnn.h"
#include "linear.h"

using namespace std;
using namespace convnet;

int main()
{
    Tensor *input = new Tensor(5);
    Tensor output[2];
    Rnn rnn(5, 10);
    Linear linear(10, 5);
    vector<Real> w1 = {
        1.0875e-01,  1.9007e-01,  3.3171e-02, -4.2167e-01,  2.3260e-01,
        -4.0091e-02,  5.0844e-01, -4.9908e-03, -4.6057e-01, -3.2492e-01,
        -4.4841e-01, -1.9543e-01,  4.1142e-01,  3.3704e-01, -7.6349e-02,
        2.7550e-01,  1.1706e-01,  5.2413e-01,  3.6117e-01, -5.5991e-01,
        3.9840e-01,  2.3983e-02, -3.0162e-01, -6.0204e-02, -1.3524e-01,
        3.9064e-04, -4.0791e-01,  2.9194e-01, -3.2485e-01,  1.1633e-01,
        2.1486e-01,  1.3768e-01,  3.2932e-01, -2.1038e-01,  3.9599e-01,
        -2.4395e-01, -4.5648e-02,  3.3629e-01, -4.9821e-01, -9.9775e-02,
        4.2604e-01, -6.1709e-01,  1.4171e-02, -3.5881e-01, -3.0136e-01,
        -1.1071e-01,  4.3087e-01, -7.1492e-02, -4.4776e-02, -1.9427e-01
    };
    vector<Real> w2 = {
         0.1804, -0.0511, -0.1194, -0.3126,  0.3056,  0.2208,  0.2536, -0.2775, 0.1334, -0.0084,
        -0.0078,  0.0143,  0.0403,  0.1966, -0.0028,  0.0869,  0.0081, -0.2408, 0.0628,  0.0728,
        0.0433, -0.2915,  0.2838, -0.1858, -0.0760, -0.2338,  0.0192,  0.2064, -0.0470, -0.2736,
        -0.1543, -0.0061, -0.0271,  0.1564, -0.1332,  0.2041, -0.0063, -0.0483, 0.3013, -0.0242,
        -0.0377,  0.1239, -0.1080,  0.2230,  0.1908, -0.2534, -0.2355, -0.2026, -0.0397, -0.0283,
        0.3074, -0.1016, -0.2998, -0.2427, -0.0007, -0.1828, -0.0867, -0.2579, 0.2764, -0.1827,
        -0.0062,  0.0415, -0.1900,  0.1646, -0.0817,  0.1933, -0.1867, -0.0074, -0.3107, -0.2211,
        0.0158,  0.3108, -0.0322,  0.0481,  0.2690,  0.1093, -0.2631,  0.2370, -0.1548, -0.2132,
        -0.2503,  0.2321, -0.0190, -0.2398,  0.1281, -0.2103,  0.3047,  0.3008, -0.2617, -0.1564,
        0.0903, -0.2276,  0.1263,  0.0693, -0.2775,  0.2864,  0.1292, -0.3017, 0.1994, -0.1917
    };
    vector<Real> w3 = {
        0.2785, -0.4658, -0.1571, -0.7094, -0.3439,  0.1966,  0.2905, -0.1365, -0.1963, -0.0616,
        0.0765,  0.0545, -0.4994,  0.2200,  0.4290,  0.0951, -0.0870, -0.2526, 0.5923, -0.1664,
        0.3581,  0.4772, -0.1968,  0.1624, -0.0177, -0.2860,  0.0372,  0.0018, -0.4826,  0.4231,
        0.2152,  0.0899,  0.4478,  0.4280, -0.4368,  0.3140,  0.1704,  0.4542, 0.2319,  0.0220,
        -0.2668, -0.4524,  0.2827,  0.3360, -0.0086, -0.2787, -0.1835, -0.5607, -0.2900, -0.1088
    };
    vector<Real> b1 = {0.2237, -0.1793,  0.1222, -0.0701, -0.1727,  0.0800,  0.2516, -0.3360, -0.0056,  0.0087};
    vector<Real> b2 = {0.1713, -0.0580, -0.1222,  0.1427,  0.0238, -0.0174,  0.0815,  0.2246, 0.0837,  0.0762};
    vector<Real> b3 = {-0.1141,  0.2620,  0.0653,  0.0189,  0.0603};
    vector<Real> x1 = {1,0,0,0,0};
    char chx1 = 0x00;
    int y1;

    printf("请输入测试字母: ");

    scanf("%c", &chx1);

    if(chx1 == 'a')
    {
        x1 = {1, 0, 0, 0, 0};
    }
    else if(chx1 == 'b')
    {
        x1 = {0, 1, 0, 0, 0};
    }
    else if(chx1 == 'c')
    {
        x1 = {0, 0, 1, 0, 0};
    }
    else if(chx1 == 'd')
    {
        x1 = {0, 0, 0, 1, 0};
    }
    else if(chx1 == 'e')
    {
        x1 = {0, 0, 0, 0, 1};
    }
    else
    {
        delete input;

        return -1;
    }

    rnn.set(input, &output[0]);
    linear.set(&output[0], &output[1]);

    rnn.setargs(w1, w2, b1, b2);
    linear.setargs(w3, b3);

    input->set(x1);
    rnn.forward();
    linear.forward();

    output[1].argmax(y1);

    printf("预测字母为: %c\n", y1+'a');

    delete input;

    return 0;
}

9.Makefile

# Build the RNN demo. Recipe lines must start with a TAB character.
CXX=g++
STD=-std=c++11
DEBUG=-g
LDFLAGS=
# Was misspelled CXXFLASG (and referenced as CXXLFAGS), so flags never applied.
CXXFLAGS=
OBJS=tensor.o rnn.o linear.o

# main.cpp includes all three headers, so any header change must relink.
rnn: main.cpp tensor.h rnn.h linear.h $(OBJS)
	$(CXX) $(DEBUG) -o rnn main.cpp $(OBJS) $(STD) $(CXXFLAGS) $(LDFLAGS)

tensor.o: tensor.cpp tensor.h
	$(CXX) $(DEBUG) -c tensor.cpp $(STD) $(CXXFLAGS)

# rnn.h and linear.h include tensor.h, so Tensor layout changes rebuild them too.
rnn.o: rnn.cpp rnn.h tensor.h
	$(CXX) $(DEBUG) -c rnn.cpp $(STD) $(CXXFLAGS)

linear.o: linear.cpp linear.h tensor.h
	$(CXX) $(DEBUG) -c linear.cpp $(STD) $(CXXFLAGS)

.PHONY: clean
clean:
	rm -rf rnn
	rm -rf $(OBJS)

10.编译源码

$ make

11.运行及其结果

$ ./rnn 
请输入测试字母: a
预测字母为: b
$ ./rnn 
请输入测试字母: b
预测字母为: c
$ ./rnn 
请输入测试字母: c
预测字母为: d
$ ./rnn 
请输入测试字母: d
预测字母为: e
$ ./rnn 
请输入测试字母: e
预测字母为: a
上一篇 下一篇

猜你喜欢

热点阅读