为什么Recompile之后你的网络不收敛了

2018-06-11  本文已影响0人  AlchemistMartin

最近在做的一个项目需要把训练分为两个阶段,使用一个loss函数训练一段时间之后再换另一个,在初始化模型的时候没有用callback的方式,而是直接进行了两次compile,但由于手贱改了下第一个阶段的学习率,发现本来好好的模型突然不收敛了。也算是keras使用过程中不那么明显的一个坑,在这里记录一下分析的过程和以后避免跳坑的教训。


因为使用的数据比较敏感,所以为便于表达,这里换用MNIST数据集的数据,问题相关的代码:

l1 = Input((in_dims, ))
l2 = Dense(200, activation='relu', kernel_initializer=glorot_uniform(seed=0), bias_initializer='zeros')(l1)
l3 = Dense(50, activation='relu', kernel_initializer=glorot_uniform(seed=0), bias_initializer='zeros')(l2)
l4 = Dense(out_dims, activation='relu', kernel_initializer=glorot_uniform(seed=0), bias_initializer='zeros')(l3)

optimizer_1 = sgd(lr=5e1)
optimizer_2 = sgd(lr=5e-3)

m_model = Model(inputs=l1, outputs=l4)
m_model.compile(optimizer=optimizer_1, loss=categorical_crossentropy, metrics=['acc'])
m_model.fit(x=x_train, y=y_train, batch_size=512, epochs=3, verbose=1, validation_data=(x_val, y_val))

m_model.compile(optimizer=optimizer_2, loss=categorical_crossentropy, metrics=['acc'])
m_model.fit(x=x_train, y=y_train, batch_size=512, epochs=3, verbose=1, validation_data=(x_val, y_val))

模型的收敛情况如下图,lr为5的时候不收敛,这个符合预期,但是在修改了lr之后也没有收敛:

两个阶段模型的收敛情况,左边是lr=5, 右边将lr修改为5e-3重新compile之后训练
而在调换了optimizer_1和optimizer_2的顺序之后,lr为5e-3的阶段还是学习了一点,有一定的精度,到第二个阶段lr改为5才彻底挂掉:
调换optimizer_1和optimizer_2的顺序后,lr为5e-3的部分还是收敛到了一定精度
所以为什么呢?第一反应可能是compile这里出了问题,是不是compile之后原来的学习率还以某种形式在网络中存在?但是在第二张图中我们可以看到compile之后模型的不收敛情况完全符合预期,所以compile不应该引入这样的问题,问题还是和使用这两个学习率的先后顺序有关。compile的作用是配置模型的优化器、损失函数以及metric等,并不会影响模型的内部参数。

我们打开lr=5阶段的TensorBoard再分析一下,这下可以看到是梯度完全消失了,并且权重和偏置基本都是绝对值非常小的负值:

这里是随便拿了一层来看,看右边两个图的grad就会发现梯度基本为0
到了这里结论其实很清楚了,实际上发生的问题是一开始使用了过大的学习率,导致几乎全部节点都发生了dead relu,而在使用更小的学习率进行compile之后,原本的l1, l2, l3, l4中的权重和偏置并没有被重新初始化,这样以后就不会再有梯度了,再怎么改学习率或优化器都无法收敛了。

当然为了方便描述,在这个例子里边用了5这么一个几乎不会用的学习率。推广到更一般的情况,如果使用的学习率过大,导致线性激活函数死掉,或者导致其它激活函数进入了离原点很远的饱和区间,那后边再如何修改参数都于事无补了。

在实际使用中,recompile很可能会在另一个场景给你错觉:ensemble是提升模型性能的一个方法,我们通过多次采样数据,使用不同的超参训练多个模型,最后使用这些模型预测的均值。比如我们用多次采样train+dev集合的方法,训练多次,但中间忘了重新初始化模型参数,而只是用一个compile,你会发现不论怎么采样,模型的在训练集的精度都是稳步上升的,但应用到测试集往往效果不怎么样。原因就是我们在训练过程中,越往后验证集中的数据就越有可能已经在之前的训练集中包含过,这样肯定性能会越来越好,但实际上只是过拟合了。

操作上的建议是:

上一篇 下一篇

猜你喜欢

热点阅读