pytorch实现RNN以及LSTM/GRU
2021-01-10 本文已影响0人
升不上三段的大鱼
pytorch提供了很方便的RNN模块,以及其他结构像LSTM和GRU。
pytorch里的RNN需要的参数主要有:
- input_size:input_tensor的形状是(序列长度, batch大小,input_size)
- hidden_size:可以自己定义大小,是一个需要调的参数,hidden state是(RNN的层数*方向,batch,hidden_size),这里的方向默认是1,如果是双向的RNN,方向则是2.
- num_layers:RNN也可以堆叠起来,默认是1层,可以设置层数。
- batch_first:第一维是非为batch size,默认为false;如果设为true,意味着输入和输出的第一维是batch。
代码实现很简单,对于一个一层的RNN,实现多对一的分类:
class RNN(nn.Module):
def __init__(self, input_dim, hidden_dim, num_layers, num_class):
super(RNN,self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
# input: (batch_size, sequence_size, input_size)
# many to one mode
self.fc = nn.Linear(hidden_dim, num_class)
def forward(self,x):
# initialize hidden state
h0 = torch.zeros(self.num_layers, x.size(0),self.hidden_dim).to('cuda:0')
# output size: (batch_size, sequnce_size, hidden_size)
out, _ = self.rnn(x)
out = self.fc(out)
return out
多对一指的是一个序列的输入对应着一个值的输出,多对多指的是一个序列输入对应着一个序列的输出。
RNN层的输出有两个,一个是最后一层RNN在所有时间上的输出,另一个是最后一个隐含状态,这里我们只需要一个输出就够了,再加上一个线性层用于分类。
LSTM/GRU和RNN的用法基本一致,只有改一下名字就行了。
对于一个一维的序列,假设batch size为128, 序列长度为256,输入维度为1,分类数目为3,输入的shape为(128,256,1),如果是多对一的分类,得到的输出为(128,3);如果是多对多的分类,得到的输出是(128,256,3).在输入序列的256个时间点上都有输出。