(五)tensorflow 1.x中关于BN层的坑
在最近进行模型训练时,遇到了一些BN层的坑,特此记录一下。
问题描述:
模型训练的时候,训练集上的准确率很高,测试集的表现很差,排除了其他原因后,锁定在了slim的batch_norm的使用上。
解决方案
1、设置依赖
slim.batch_norm源码
Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
One can set updates_collections=None to force the updates in place, but that can have a speed penalty, especially in distributed settings.
在训练时,moving_mean 和 moving_variance 默认是添加到tf.GraphKeys.UPDATE_OPS 中的, 因此需要作为一个依赖项,在更新train_op时跟新参数。将 updates_collections参数设置为None,这样会在训练时立即更新,影响速度。
2、设置decay参数
Lower decay value (recommend trying decay=0.9) if model experiences reasonably good training performance but poor validation and/or test performance. Try zero_debias_moving_mean=True for improved stability.
由于使用BN层的网络,预测的时候要用到估计的总体均值和方差,如果iteration还比较少的时候就急着去检验或者预测的话,可能这时EMA估计得到的总体均值/方差还不accurate和stable, 所以会造成训练和预测悬殊,这种情况就是造成下面这个issue的原因:https://github.com/tensorflow/tensorflow/issues/7469 解决的办法就是:当训练结果远好于预测的时候,那么可以通过减小decay,早点“热身”。
默认decay=0.999,一般建议使用0.9
3、模型保存
当我们使用batch_norm时,slim.batch_norm中的moving_mean和moving_variance不是trainable的, 所以使用saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)无法保存, 应该改为:
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=3)