数据科学机器学习-入门深度学习

技能 | 三次简化一张图: 一招理解LSTM/GRU门控机制

2017-09-05  本文已影响199人  AI科技大本营

作者 | 张皓

引言


RNN是深度学习中用于处理时序数据的关键技术, 目前已在自然语言处理, 语音识别, 视频识别等领域取得重要突破, 然而梯度消失现象制约着RNN的实际应用。LSTM和GRU是两种目前广为使用的RNN变体,它们通过门控机制很大程度上缓解了RNN的梯度消失问题,但是它们的内部结构看上去十分复杂,使得初学者很难理解其中的原理所在。本文介绍”三次简化一张图”的方法,对LSTM和GRU的内部结构进行分析。该方法非常通用,适用于所有门控机制的原理分析。


预备知识: RNN


RNN (recurrent neural networks, 注意不是recursiveneural networks)提供了一种处理时序数据的方案。和n-gram只能根据前n-1个词来预测当前词不同, RNN理论上可以根据之前所有的词预测当前词。在每个时刻, 隐层的输出ht依赖于当前词输入xt和前一时刻的隐层状态ht-1:



其中:=表示"定义为", sigm代表sigmoid函数sigm(z):=1/(1+exp(-z)), Wxh和Whh是可学习的参数。结构见下图:


图中左边是输入,右边是输出。xt是当前词,ht-1记录了上文的信息。xt和ht-1在分别乘以Wxh和Whh之后相加,再经过tanh非线性变换,最终得到ht。

 在反向传播时,我们需要将RNN沿时间维度展开,隐层梯度在沿时间维度反向传播时需要反复乘以参数。因此, 尽管理论上RNN可以捕获长距离依赖, 但实际应用中,根据谱半径(spectralradius)的不同,RNN将会面临两个挑战:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。梯度爆炸会影响训练的收敛,甚至导致网络不收敛;而梯度消失会使网络学习长距离依赖的难度增加。这两者相比, 梯度爆炸相对比较好处理,可以用梯度裁剪(gradientclipping)来解决,而如何缓解梯度消失是RNN及几乎其他所有深度学习方法研究的关键所在。

 

LSTM

LSTM通过设计精巧的网络结构来缓解梯度消失问题,其数学上的形式化表示如下:



其中代表逐元素相乘。这个公式看起来似乎十分复杂,为了更好的理解LSTM的机制, 许多人用图来描述LSTM的计算过程, 比如下面的几张图:

似乎看完了这些图之后,你对LSTM的理解还是一头雾水? 这是因为这些图想把LSTM的所有细节一次性都展示出来,但是突然暴露这么多的细节会使你眼花缭乱,从而无处下手。

 因此,本文提出的方法旨在简化门控机制中不重要的部分,从而更关注在LSTM的核心思想。整个过程是“三次简化一张图”,具体流程如下:





 

和RNN相同的是,网络接受两个输入,得到一个输出。不同之处在于, LSTM中通过3个门控单元来对记忆单元c的信息进行交互。

 

根据这张图,我们可以对LSTM中各单元作用进行分析:





 

GRU


GRU是另一种十分主流的RNN衍生物。RNN和LSTM都是在设计网络结构用于缓解梯度消失问题, 只不过是网络结构有所不同。GRU在数学上的形式化表示如下:



为了理解GRU的设计思想,我们再一次运用“三次简化一张图”的方法来进行分析:



与LSTM相比,GRU将输入门it和遗忘门ft融合成单一的更新门zt,并且融合了记忆单元ct和隐层单元ht,所以结构上比LSTM更简单一些。

 

根据这张图,我们可以对GRU的各单元作用进行分析:



 

小结


尽管RNN, LSTM,和GRU的网络结构差别很大,但是他们的基本计算单元是一致的,都是对xt和ht-1做一个线性映射加tanh激活函数,见三个图的红色框部分。他们的区别在于如何设计额外的门控机制控制梯度信息传播用以缓解梯度消失现象。LSTM用了3个门,GRU用了2个,那能不能再少呢? MGU (minimal gate unit)尝试对这个问题做出回答, 它只有一个门控单元。 


最后留个小练习, 参考LSTM和GRU的例子,你能不能用“三次简化一张图”的方法来分析一下MGU呢?

 


参考文献

1. Bengio, Yoshua, PatriceSimard, and Paolo Frasconi。 "Learning long-term dependencies with gradient descent isdifficult。" IEEE transactions on neural networks 5。2 (1994):157-166。

2. Cho, Kyunghyun, et al。"Learning phrase representations using RNN encoder-decoder for statisticalmachine translation。" arXiv preprint arXiv:1406。1078 (2014)。

3. Chung, Junyoung, et al。"Empirical evaluation of gated recurrent neural networks on sequencemodeling。" arXiv preprint arXiv:1412。3555 (2014)。

4. Gers, Felix。 "Longshort-term memory in recurrent neural networks。" UnpublishedPhD dissertation, Ecole Polytechnique Fédérale de Lausanne, Lausanne, Switzerland(2001)。

5. Goodfellow, Ian, YoshuaBengio, and Aaron Courville。 Deep learning。 MIT press, 2016。

6. Graves, Alex。 Supervisedsequence labelling with recurrent neural networks。 Vol。 385。 Heidelberg:Springer, 2012。

7. Greff, Klaus, et al。 "LSTM:A search space odyssey。" IEEE transactions on neural networks and learning systems(2016)。

8. He, Kaiming, et al。 "Deepresidual learning for image recognition。" Proceedingsof the IEEE conference on computer vision and pattern recognition。 2016。

9. He, Kaiming, et al。"Identity mappings in deep residual networks。" EuropeanConference on Computer Vision。 Springer International Publishing, 2016。

10. Hochreiter, Sepp, and JürgenSchmidhuber。 "Long short-term memory。" Neuralcomputation 9。8 (1997): 1735-1780。

11. Jozefowicz, Rafal, WojciechZaremba, and Ilya Sutskever。 "An empirical exploration of recurrent network architectures。" Proceedingsof the 32nd International Conference on Machine Learning (ICML-15)。 2015。

12. Li, Fei-Fei, JustinJohnson, and Serena Yeung。 CS231n: Convolutional Neural Networks for Visual Recognition。 Stanford。 2017。

13. Lipton, Zachary C。, JohnBerkowitz, and Charles Elkan。 "A critical review of recurrent neural networks for sequencelearning。" arXiv preprint arXiv:1506。00019 (2015)。

14. Manning, Chris andRichard Socher。 CS224n: Natural Language Processing with Deep Learning。 Stanford。 2017。

15. Pascanu, Razvan, Tomas Mikolov, and YoshuaBengio。 "On the difficulty of training recurrent neural networks。"International Conference on Machine Learning。 2013。

16. Srivastava, RupeshKumar, Klaus Greff, and Jürgen Schmidhuber。 "Highwaynetworks。" arXiv preprint arXiv:1505。00387 (2015)。

17. Williams, D。 R。 G。 H。 R。, andGeoffrey Hinton。 "Learning representations by back-propagating errors。"Nature 323。6088 (1986): 533-538。

18. Zhou, Guo-Bing, et al。"Minimal gated unit for recurrent neural networks。"International Journal of Automation and Computing 13。3 (2016):226-234。



本文是投稿文章,作者:张皓

github地址:https://github.com/HaoMood/

注:AI科技大本营现已开通投稿通道,投稿请加编辑微信1092722531

更多资讯请关注微信公众号:AI科技大本营(rgznai100)

上一篇下一篇

猜你喜欢

热点阅读