Train Loss保持下降,Valid Loss大幅度波动下降

2018-11-23  本文已影响0人  药柴

在使用PyTorch进行PointCNN的构建和实现中,发现模型在训练过程中Loss保持稳定下降,但是在验证过程中,出现完全不合理的10e9级别的Loss。考虑到训练集和验证集是完全从同一数据集中采样出来的,不可能会在数据分布上出现明显的差异,因此排除数据不一致的原因。
详细检查了模型在训练和验证过程中的输出,发现由于最后的一层BatchNormalization,模型在训练过程中的输出是接近均值为零,方差为一的。而验证过程中,模型的输出完全没有遵从这个分布。因此,可以认为,BatchNormalization在验证过程中,没有发挥它的作用。
考虑到模型内部,显式对数据分布进行调整的计算,还是主要在BatchNormalization层,因此首先调查这一方面。
结果发现,PyTorch Forum上有人提到了相似的问题Model.eval() gives incorrect loss for model with batchnorm layers。在这里,PyTorch Dev, Facebook AI Research的smth提到

it is possible that your training in general is unstable, so BatchNorm’s running_mean and running_var dont represent true batch statistics.
http://pytorch.org/docs/master/nn.html?highlight=batchnorm#torch.nn.BatchNorm1d
Try the following:

  • change the momentum term in BatchNorm constructor to higher.
  • before you set model.eval(), run a few inputs through model (just forward pass, you dont need to backward). This will help stabilize the running_mean / running_std values.

即,BatchNormalization层内的,随训练而不断更新的拟合的数据分布,没有能匹配真实的batch数据分布。
推荐将BatchNorm内的momentum项设置的比较高,或是在将模型调到model.eval()模式前,先将部分测试的数据在模型内前向传播一下,让BatchNorm层可以更新一下这个估计。
本文试验了一下调高momentum项,但没有明显的效果。
本文解决这个问题通过另一位网友cakeeatingpolarbear提出的方法,将BatchNorm函数内的track_running_stats设置为False,则模型会在任何模式下保持进行对数据分布的拟合。

上一篇下一篇

猜你喜欢

热点阅读