BN(Batch Normalization)和TF2的BN层

2019-09-29  本文已影响0人  WILeroy

1、Batch Normalization

在讨论Batch Normalization之前,先讨论一下feature scaling可能会对后续的讨论有很大的帮助。feature scaling,即特征归一化,是机器学习领域中一种通用的数据预处理方法,其目的是将模式向量中尺度不一致的不同维度特征归一到同一尺度,以保证训练速度与精度。

假设有一个大小为n的数据集X^{1,...,n},其中每个模式向量有m个维度的特征X^{i} = {x_{1,...,m}}。如果在这个数据集中,第i维的特征x_{i}服从均值为0、方差为1的高斯分布,而第j维的特征x_{j}服从均值为200,方差为1的高斯分布,那么这个数据集将难以用于模型训练。其原因在于,x_{i}x_{j}的分布相差甚远,模型中与x_{i}相关的参数只进行很小的改变往往难以对结果造成显著性的改变,而与x_{j}相关的参数则恰恰相反,这让训练过程的learning rate很难统一,过小收敛过慢,过大则可能不收敛。

为了解决以上问题,feature scaling对每个维度的特征都进行如下变换,变换的结果则是所有维度的特征都归一化到均值为0、方差为1这个尺度:

\hat{x}_{i}^{r} = \frac{x_{i}^{r} - m_{i}}{\sigma_{i}}

以上方法对于模型的训练是十分有效的,而在深度神经网络的研究中,研究人员延续这种思路提出了Batch Normalization。相对于传统的模型,深度神经网络遇到的问题是,随着网络深度增加,网络中一个小小的改变可能在经过若干层的传播之后令整个网络出现极大的波动,如bp过程中的梯度消失与爆炸(事实上,ReLU、有效的初始化、设置更小的learning rate等方法都能用于解决该问题)。

Batch Normalization可以用于解决深度神经网络的Internal Covariate Shift问题,其实质是:使用一定的规范化方法,把每个隐层神经元的输入控制为均值为0、方差为1的标准正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域(如Sigmoid函数只在0附近具有较好的梯度),以此避免梯度消失问题。

在Batch Normalization中,Batch是指每次训练时网络的输入都是一批训练数据,这一批数据会同时经过网络的一层,然后在经过WX^{i}+b=Z^{i}之后,网络再一起对这一批数据的Z^{i}做规范化处理。当然,Batch Normalization的论文中还使用了两个参数处理规范化之后的数据,即\hat{Z}^{i} = \gamma\odot\tilde{z}^{i}+\beta。事实上,如果\gamma=\sigma\beta=\mu,这就等价于Normalization的一个逆运算,那么normalization的意义似乎就不存在了,但是,事实并非如此,因为\mu\sigmaZ^{i}相关,而\gamma\beta则完全独立,二者并不等价。合理的解释是,后续操作是为了防止normalization矫枉过正增加的人为扰动。Batch Normalization的具体结构如下所示:

bn

2、TF2的BN层

在tensorflow2中使用BN层的方法如下,需要注意的是BN层在训练和推理两种模式下存在不同。

BN层有4*num_channels个参数,每4个参数对应一个通道,分别是\mu, \sigma, \beta, \gamma。其中\beta, \gamma和其他层的参数的逻辑是一致的,训练时不断调整,推理时不再改变(即只有优化器更新参数时才会改变)。而\mu, \sigma不同,在推理时,即使没有优化器更新参数,也可能不断变化。这两个参数受BatchNormalization层的参数training控制,当training=False时,二者为移动均值和方差(固定);当training=True时,二者与每次输入的batch相关,\mu, \sigma是当前batch的均值、方差。

综上,在使用TF2的BN层时,推理时需要指定当前模式为推理模式,方法如下(还存在其他方法,如显示地声明training参数为False)。此外,BN层也有trainable参数,和其他层一样,该参数意在冻结\beta, \gamma两个参数,但是当trainable=True时,该BN层会以推理模式运行,\mu, \sigma两个参数也就随之固定。

import tensorflow as tf

# BN层的使用
tf.keras.layers.BatchNormalization()(x)

# 训练、推理模式的选择,0-推理、1-训练
tf.keras.backend.set_learning_phase(0)
上一篇下一篇

猜你喜欢

热点阅读