RNN梯度消失与梯度爆炸推导
2021-10-14 本文已影响0人
walkerwzy
![](https://img.haomeiwen.com/i1859625/1a8b2a5bad603588.png)
梯度消失与爆炸
假设一个只有 3 个输入数据的序列,此时我们的隐藏层 h1、h2、h3 和输出 y1、y2、y3 的计算公式:
RNN 在时刻 t 的损失函数为 Lt,总的损失函数为
t = 3 时刻的损失函数 L3 对于网络参数 U、W、V 的梯度如下:
其实主要就是因为:
- 对V求偏导时,
是常数
- 对U求偏导时:
-
里有U,所以要继续对h3应用
chain rule
-
里的
是常数,但是
里又有U,继续
chain rule
- 以此类推,直到
-
- 对W求偏导时一样
所以:
- 参数矩阵 V (对应输出
) 的梯度很显然并没有长期依赖
- U和V显然就是连乘(
)后累加(
)
其中的连乘项就是导致 RNN 出现梯度消失与梯度爆炸的罪魁祸首,连乘项可以如下变换:
tanh' 表示 tanh 的导数,可以看到 RNN 求梯度的时候,实际上用到了 (tanh' × W) 的连乘。当 (tanh' × W) > 1 时,多次连乘容易导致梯度爆炸;当 (tanh' × W) < 1 时,多次连乘容易导致梯度消失。