Tensorflow——BatchNormalization(t
2018-10-30 本文已影响8人
SpareNoEfforts
批标准化
- 批标准化(batch normalization,BN)一般用在激活函数之前,使结果 各个维度均值为0,方差为1。通过规范化让激活函数分布在线性区间,让每一层的输入有一个稳定的分布会有利于网络的训练。
- 优点:
加大探索步长,加快收敛速度。
更容易跳出局部极小。
破坏原来的数据分布,一定程度上防止过拟合。
解决收敛速度慢和梯度爆炸。
tensorflow相应API
-
mean, variance = tf.nn.moments(x, axes, name=None, keep_dims=False)
- 计算统计矩,mean 是一阶矩即均值,variance 则是二阶中心矩即方差,axes=[0]表示按列计算;
- 对于以feature map 为维度的全局归一化,若feature map 的shape 为[batch, height, width, depth],则将axes赋值为[0, 1, 2]
-
tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
- tf.nn.batch_norm_with_global_normalization(x, mean, variance, beta, gamma, variance_epsilon, scale_after_normalization, name=None);
- tf.nn.moments 计算返回的 mean 和 variance 作为 tf.nn.batch_normalization 参数调用;
tensorflow及python实现
计算每个列的均值及方差。
import tensorflow as tf
W = tf.constant([[-2.,12.,6.],[3.,2.,8.]], )
mean,var = tf.nn.moments(W, axes = [0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
resultMean = sess.run(mean)
print(resultMean)
resultVar = sess.run(var)
print(resultVar)
#输出
[ 0.5 7. 7. ]
[ 6.25 25. 1. ]
标准化
size = 3
scale = tf.Variable(tf.ones([size]))
shift = tf.Variable(tf.zeros([size]))
epsilon = 0.001
W = tf.nn.batch_normalization(W, mean, var, shift, scale, epsilon)
#参考下图BN的公式,相当于进行如下计算
#W = (W - mean) / tf.sqrt(var + 0.001)
#W = W * scale + shift
with tf.Session() as sess:
#必须要加这句不然执行多次sess会报错
sess.run(tf.global_variables_initializer())
resultW = sess.run(W)
print(resultW)
#观察初始W第二列 12>2 返回BN的W值第二列第二行是负的,其余两列相反
#输出
[[-0.99992001 0.99997997 -0.99950027]
[ 0.99991995 -0.99997997 0.99950027]]