pytorch实现RNN以及LSTM/GRU

2021-01-10  本文已影响0人  升不上三段的大鱼

pytorch提供了很方便的RNN模块,以及其他结构像LSTM和GRU。
pytorch里的RNN需要的参数主要有:

代码实现很简单,对于一个一层的RNN,实现多对一的分类:

class RNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_class):
        super(RNN,self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        # input: (batch_size, sequence_size, input_size)
        # many to one mode
        self.fc = nn.Linear(hidden_dim, num_class)

    def forward(self,x):
        # initialize hidden state
        h0 = torch.zeros(self.num_layers, x.size(0),self.hidden_dim).to('cuda:0')
        # output size: (batch_size, sequnce_size, hidden_size)
        out, _ = self.rnn(x)
        out = self.fc(out)
        return out

多对一指的是一个序列的输入对应着一个值的输出,多对多指的是一个序列输入对应着一个序列的输出。
RNN层的输出有两个,一个是最后一层RNN在所有时间上的输出,另一个是最后一个隐含状态,这里我们只需要一个输出就够了,再加上一个线性层用于分类。
LSTM/GRU和RNN的用法基本一致,只有改一下名字就行了。

对于一个一维的序列,假设batch size为128, 序列长度为256,输入维度为1,分类数目为3,输入的shape为(128,256,1),如果是多对一的分类,得到的输出为(128,3);如果是多对多的分类,得到的输出是(128,256,3).在输入序列的256个时间点上都有输出。

上一篇下一篇

猜你喜欢

热点阅读