Train Loss保持下降,Valid Loss大幅度波动下降
在使用PyTorch进行PointCNN的构建和实现中,发现模型在训练过程中Loss保持稳定下降,但是在验证过程中,出现完全不合理的10e9级别的Loss。考虑到训练集和验证集是完全从同一数据集中采样出来的,不可能会在数据分布上出现明显的差异,因此排除数据不一致的原因。
详细检查了模型在训练和验证过程中的输出,发现由于最后的一层BatchNormalization
,模型在训练过程中的输出是接近均值为零,方差为一的。而验证过程中,模型的输出完全没有遵从这个分布。因此,可以认为,BatchNormalization
在验证过程中,没有发挥它的作用。
考虑到模型内部,显式对数据分布进行调整的计算,还是主要在BatchNormalization
层,因此首先调查这一方面。
结果发现,PyTorch Forum上有人提到了相似的问题Model.eval() gives incorrect loss for model with batchnorm layers。在这里,PyTorch Dev, Facebook AI Research的smth提到
it is possible that your training in general is unstable, so BatchNorm’s running_mean and running_var dont represent true batch statistics.
http://pytorch.org/docs/master/nn.html?highlight=batchnorm#torch.nn.BatchNorm1d
Try the following:
- change the momentum term in BatchNorm constructor to higher.
- before you set model.eval(), run a few inputs through model (just forward pass, you dont need to backward). This will help stabilize the running_mean / running_std values.
即,BatchNormalization
层内的,随训练而不断更新的拟合的数据分布,没有能匹配真实的batch数据分布。
推荐将BatchNorm
内的momentum
项设置的比较高,或是在将模型调到model.eval()
模式前,先将部分测试的数据在模型内前向传播一下,让BatchNorm
层可以更新一下这个估计。
本文试验了一下调高momentum项,但没有明显的效果。
本文解决这个问题通过另一位网友cakeeatingpolarbear提出的方法,将BatchNorm
函数内的track_running_stats
设置为False
,则模型会在任何模式下保持进行对数据分布的拟合。