Batch Norm和Layer Norm

2020-06-23  本文已影响0人  京漂的小程序媛儿

深度学习中的ICS问题?

covariate shift 是分布不一致假设之下的一个分支问题,它是指源空间和目标空间的条件概率是一致的,但是其边缘概率不同。
而统计机器学习中的一个经典假设是 “源空间(source domain)和目标空间(target domain)的数据分布(distribution)是一致的”。

ICS导致什么问题?

每个神经元的输入数据不再是 “独立同分布”。
其一,上层参数需要不断适应新的输入数据分布,降低学习速度。
其二,下层输入的变化可能趋向于变大或者变小,导致上层落入饱和区,使得学习过早停止。
其三,每层的更新都会影响到其它层,因此每层的参数更新策略需要尽可能的谨慎。

引入 Normalization

由于 ICS 问题的存在,x 的分布可能相差很大。要解决独立同分布的问题,“理论正确” 的方法就是对每一层的数据都进行白化操作。然而标准的白化操作代价高昂,且不可微不利于反向传播更新梯度。
因此,以 BN 为代表的 Normalization 方法退而求其次,进行了简化的白化操作。
基本思想是:在将 x 送给神经元之前,先对其做平移和伸缩变换, 将 x 的分布规范化成在固定区间范围的标准分布。

通用变换框架

通用归一化框架

最终得到的数据符合均值为 b、方差为g平方的分布。

为什么要再平移再缩放?

为了保证模型的表达能力不因为规范化而下降。
第一步的规范化会将几乎所有数据映射到激活函数的非饱和区(线性区),仅利用到了线性变化能力,从而降低了神经网络的表达能力。而进行再变换,则可以将数据从线性区变换到非线性区,恢复模型的表达能力。

平移参数和再平移参数的区别

平移参数,x 的均值取决于下层神经网络的复杂关联;
但再平移参数中,去除了与下层计算的密切耦合。新参数很容易通过梯度下降来学习,简化了神经网络的训练。

Batch Norm

BN在batch维度的归一化,也就是对于每个batch,该层相应的output位置归一化所使用的mean和variance都是一样的。

BN的学习参数包含rescale和shift两个参数。
1、BN在单独的层级之间使用比较方便,比如CNN。得像RNN这样层数不定,直接用BN不太方便,需要对每一层(每个time step)做BN,并保留每一层的mean和variance。不过由于RNN输入不定长(time step长度不定),可能会有validation或test的time step比train set里面的任何数据都长,因此会造成mean和variance不存在的情况。
2、BN会引入噪声(因为是mini batch而不是整个training set),所以对于噪声敏感的方法(如RL)不太适用。

Layer Norm

LayerNorm实际就是对隐含层做层归一化,即对某一层的所有神经元的输入进行归一化。(每hidden_size个数求平均/方差)
1、它在training和inference时没有区别,只需要对当前隐藏层计算mean and variance就行。不需要保存每层的moving average mean and variance。
2、不受batch size的限制,可以通过online learning的方式一条一条的输入训练数据。
3、LN可以方便的在RNN中使用。
4、LN增加了gain和bias作为学习的参数。


Layer Norm求均值方差
Layer Norm

Bert layer norm实现代码

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # mean(-1) 表示 mean(len(x)), 这里的-1就是最后一个维度,也就是hidden_size维
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
上一篇 下一篇

猜你喜欢

热点阅读