python与Tensorflow

Tensorflow——BatchNormalization(t

2018-10-30  本文已影响8人  SpareNoEfforts

批标准化

tensorflow相应API

tensorflow及python实现

计算每个列的均值及方差。

import tensorflow as tf
W = tf.constant([[-2.,12.,6.],[3.,2.,8.]], )
mean,var = tf.nn.moments(W, axes = [0])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    resultMean = sess.run(mean)
    print(resultMean)
    resultVar = sess.run(var)
    print(resultVar)
#输出
[ 0.5      7.    7.  ]
[ 6.25   25.     1.  ]

标准化

size = 3
scale = tf.Variable(tf.ones([size]))
shift = tf.Variable(tf.zeros([size]))
epsilon = 0.001
W = tf.nn.batch_normalization(W, mean, var, shift, scale, epsilon)
#参考下图BN的公式,相当于进行如下计算
#W = (W - mean) / tf.sqrt(var + 0.001)
#W = W * scale + shift

with tf.Session() as sess:
    #必须要加这句不然执行多次sess会报错
    sess.run(tf.global_variables_initializer())
    resultW = sess.run(W)
    print(resultW)
#观察初始W第二列 12>2 返回BN的W值第二列第二行是负的,其余两列相反

#输出
[[-0.99992001  0.99997997 -0.99950027]
 [ 0.99991995 -0.99997997  0.99950027]]
上一篇下一篇

猜你喜欢

热点阅读