Deep Learning

A Self-Implemented BatchNormalize Layer

2020-03-12  追光者876

In my view, batch normalization is a very important yet easily overlooked topic, and nearly every modern neural network uses it. Testing on the cifar10 dataset, I found that the same network reached roughly 10% higher validation accuracy with a BN layer than without one. This also matches what Andrew Ng says in his course: the BN layer has a slight regularization effect.
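As a quick reference for what the layer below computes, this is the standard batch normalization transform, together with the exponential-moving-average update the code uses for its running statistics:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$

During training, $\mu$ and $\sigma^2$ are the per-channel statistics of the current batch; at inference the running averages are used instead, updated each training step as $\mu_{\text{run}} \leftarrow \text{decay} \cdot \mu_{\text{run}} + (1 - \text{decay}) \cdot \mu_{\text{batch}}$ (and likewise for the variance).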

import tensorflow as tf

class BatchNormalize(tf.keras.layers.Layer):
    def __init__(self, name='BatchNormal', **kwargs):
        super(BatchNormalize, self).__init__(name=name, **kwargs)
        self._epsilon = 0.001
        self._decay = 0.99

    def build(self, input_shape):
        # Running statistics used at inference time (not updated by backprop)
        self._mean = self.add_weight(name='mean', shape=[input_shape[-1]], dtype=tf.float32,
                                     initializer=tf.zeros_initializer(), trainable=False)
        self._variance = self.add_weight(name='variance', shape=[input_shape[-1]], dtype=tf.float32,
                                         initializer=tf.ones_initializer(), trainable=False)
        # Learnable scale and shift
        self._gamma = self.add_weight(name='gamma', shape=[input_shape[-1]], dtype=tf.float32,
                                      initializer=tf.ones_initializer(), trainable=True)
        self._beta = self.add_weight(name='beta', shape=[input_shape[-1]], dtype=tf.float32,
                                     initializer=tf.zeros_initializer(), trainable=True)
        # Reduce over the batch axis (plus spatial axes for NHWC inputs)
        # so the statistics are per channel
        self._axes = [0] if len(input_shape) == 2 else [0, 1, 2]

    def call(self, inputs, training=None):
        if training:
            # Per-channel statistics of the current batch
            # (the TF2 argument is keepdims, not keep_dims)
            batch_mean, batch_variance = tf.nn.moments(inputs, axes=self._axes, keepdims=False, name='moment')
            # Update running averages: running = decay * running + (1 - decay) * batch
            train_mean = self._mean.assign(self._mean * self._decay + batch_mean * (1.0 - self._decay))
            train_variance = self._variance.assign(self._variance * self._decay + batch_variance * (1.0 - self._decay))
            # Ensure the updates run before normalization when tracing a graph
            with tf.control_dependencies([train_mean, train_variance]):
                return tf.nn.batch_normalization(inputs, batch_mean, batch_variance,
                                                 self._beta, self._gamma, self._epsilon, name='batch_normal')
        else:
            # At inference, normalize with the accumulated running statistics
            return tf.nn.batch_normalization(inputs, self._mean, self._variance,
                                             self._beta, self._gamma, self._epsilon)
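Here is a minimal sketch of how the layer might be dropped into a model. The tiny CNN and training settings are illustrative assumptions, not the exact network used in the cifar10 experiment mentioned above:

import tensorflow as tf

# Illustrative model only -- the post does not show the author's actual cifar10 network.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, padding='same', input_shape=(32, 32, 3)),
    BatchNormalize(name='bn1'),  # the custom layer defined above
    tf.keras.layers.ReLU(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Keras passes training=True during fit() and training=False at inference,
# so the layer switches between batch statistics and running statistics automatically.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
model.fit(x_train / 255.0, y_train, epochs=1,
          validation_data=(x_test / 255.0, y_test))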