人工不智能

深入浅出LSTM&GRU

2020-02-04  本文已影响0人  A君来了

我在深入浅出RNN一文中提过,RNN缺乏对信息的调控机制,无法有效利用已学到的信息(hidden state)。RNN机械地将每次学到的新知(hidden state)都揉进同一个的hidden state并将它随着循环传递下去,这样做虽然可以一直保留句子中每个token的信息,但是,当这种雨露均沾模式,遇到长句时,句首信息在hidden state中的占比就会很小,换句话说,它很容易会忘掉长句的句首甚至句中的信息。不仅如此,每个time的训练都会受到历史信息的影响(hidden state + input)。

Figure 1: LSTM & GRU

LSTM和GRU的出现就是为了弥补RNN的这些缺陷。本文将会以重构LSTM和GRU的方式来剖析LSTM和GRU,点击【这里】可以查看完整源码。

LSTM

正如Figure 1所示,LSTM通过引入input gate、forget gate和output gate来调控input和hidden state。这里的gate是由sigmoid函数实现的,它可以将任意input转换成0~1的值,将这些值和hidden state进行element-wise相乘运算,就起到了前文所说的信息调控作用:0表示丢弃信息,1表示完整地保留信息,0.x表示按比例保留信息。

Figure 2: Sigmoid

强烈推荐illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation这篇博文,作者通过动图将各个gate的数据流清晰地呈现在读者面前,例如,forget gate:

Figure 3: forget gate.gif

nn.LSTM

class Model5(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(nv, wordvec_len)
    self.input = nn.Linear(wordvec_len, nh)
    self.rnn = nn.LSTM(nh, nh, 1, batch_first=True)
    self.out = nn.Linear(nh, nv)
    self.bn = BatchNorm1dFlat(nh)
    self.h = torch.zeros(1, bs, nh).cuda()
    self.c = torch.zeros(1, bs, nh).cuda()
  
  def forward(self, x):
    res, (h, c) = self.rnn(self.input(self.emb(x)), (self.h, self.c))
    self.h = h.detach()
    self.c = c.detach()
    return self.out(self.bn(res))

通过Pytorch提供的nn.LSTM,可以轻易地构建出基于LSTM的RNN。

def lstm_loop(cell, x, h):
  hx, cx = [], []
  h, c = h
  for o in x.transpose(0, 1): # time loop
    h, c = cell(o, (h, c))
    hx.append(h)
    cx.append(c)
  # reset shape: [batch, time, hidden size]
  return [torch.stack(hx, dim=1), torch.stack(cx, dim=1)]

class Model6(Model5):
  def __init__(self):
    super().__init__()
    self.h = torch.zeros(bs, nh).cuda()
    self.c = torch.zeros(bs, nh).cuda()
    self.cell = nn.LSTMCell(nh, nh)

  def forward(self, x):
    x = F.relu(self.input(self.emb(x)))
    h, c = lstm_loop(self.cell, x, (self.h, self.c))
    self.h = h[:, -1].detach()
    self.c = c[:, -1].detach()
    return self.out(self.bn(h))

RNN的工作原理是循环调用隐藏层来处理每个input(回看深入浅出RNN的Model1)。在这里,lstm_loop是循环体,nn.LSTMCell则是隐藏层。

在验证了lstm_loop的正确性后,接着就是自己动手写nn.LSTMCell。首先,需要先拿到LSTM的数学公式,它们可以在Pytorch的nn.LSTMCell源代码的注释中找到。在Jupyter Notebook中查看函数或类的源代码的方法很简单,只要在模块前加上“??”:

??nn.LSTMCell

\begin{array}{ll} i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi})\\ f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}

上述数学公式中的ifo分别是input gate、forget gate和output gate,c'h'分别是要输出到下一个time的cell state和hidden state,\sigma是sigmoid,W_{ii} x + b_{ii}则是linear(x)。

class LSTMCell(nn.Module):
  def __init__(self, nin, nh):
    super().__init__()
    self.lin_x = nn.Linear(nin, 4 * nh)
    self.lin_h = nn.Linear(nh, 4 * nh)

  def forward(self, x, hc):
    h, c = hc
    _x = self.lin_x(x)
    _h = self.lin_h(h)
    x_i, x_f, x_o, x_g = _x.chunk(4, dim=1)
    h_i, h_f, h_o, h_g = _h.chunk(4, dim=1)
    i = torch.sigmoid(x_i + h_i)
    f = torch.sigmoid(x_f + h_f)
    o = torch.sigmoid(x_o + h_o)
    g = torch.tanh(x_g + h_g)
    c_hat = f * c + i * g
    h_hat = o * torch.tanh(c_hat)
    return (h_hat, c_hat)

在LSTMCell中,将xh的hidden sizes乘以4:self.lin_x = nn.Linear(nin, 4 * nh),再把它们等分成4份:x_i, x_f, x_o, x_g = _x.chunk(4, dim=1),这样就能将8个linear layer计算合并成2个来提升计算速度。

GRU

GRU的工作原理和LSTM很相似,也通过各种gate来调控信息,如Figure 1所示,reset gate和forget gate的功能相同,update gate则与input gate功能类似。和LSTM不同的是,GRU放弃了cell state,相应地,也就不需要output gate来生成hidden state,所以GRU的所需的计算量相比LSTM减少了1/4。模型效果类似,但计算得更快,这些特性让GRU越来越受到推崇。

按照重写LSTMCell的方法,很容易也可以重写GRUCell:

class GRUCell(nn.Module):
  def __init__(self, nin, nh):
    super().__init__()
    self.lin_x = nn.Linear(nin, 3 * nh)
    self.lin_h = nn.Linear(nh, 3 * nh)

  def forward(self, x, h):
    _x = self.lin_x(x)
    _h = self.lin_h(h)
    ir, iz, xin = _x.chunk(3, dim=1)
    hr, hz, hn = _h.chunk(3, dim=1)
    r = torch.sigmoid(ir + hr)  # reset gate
    z = torch.sigmoid(iz + hz)  # update gate
    n = torch.tanh(xin + r * hn)  # new gate
    h_hat = (1 - z) * n + (z * h)
    return h_hat

END

本文通过重构LSTM和GRU的方式详解了这两个模型的工作原理,其中的关键就是理解它们的数学公式。很多人一谈到数学公式,就本能地排斥,希望通过文字、图示这些直觉上感觉更直观的方式来辅助编程,殊不知,很多时候根据数学公式来编程反而更简单。

上一篇下一篇

猜你喜欢

热点阅读