【知识蒸馏】Deep Mutual Learning

2021-11-16  本文已影响0人  pprpp

【GiantPandaCV导语】Deep Mutual Learning是Knowledge Distillation的外延,经过测试(代码来自Knowledge-Distillation-Zoo), Deep Mutual Learning性能确实超出了原始KD很多,所以本文分析这篇CVPR2018年被接受的论文。同时PPOCRv2中也提到了DML,并提出了CML,取得效果显著。

引言

首先感谢:https://github.com/AberHu/Knowledge-Distillation-Zoo

笔者在这个基础上进行测试,测试了在CIFAR10数据集上的结果。

学生网络resnet20:92.29% 教师网络resnet110:94.31%

这里只展示几个感兴趣的算法结果带来的收益:

DML也是传统知识蒸馏的扩展,其目标也是将大型模型压缩为小的模型。但是不同于传统知识蒸馏的单向蒸馏(教师→学生),DML认为可以让学生互相学习(双向蒸馏),在整个训练的过程中互相学习,通过这种方式可以提升模型的性能。

DML通过实验证明在没有先验强大的教师网络的情况下,仅通过学生网络之间的互相学习也可以超过传统的KD。

如果传统的知识蒸馏是由教师网络指导学生网络,那么DML就是让两个学生互帮互助,互相学习。

DML

小型的网络通常有与大网络相同的表示能力,但是训练起来比大网络更加困难。那么先训练一个大型的网络,然后通过使用模型剪枝、知识蒸馏等方法就可以让小型模型的性能提升,甚至超过大型模型。

以知识蒸馏为例,通常需要先训练一个大而宽的教师网络,然后让小的学生网络来模仿教师网络。通过这种方式相比直接从hard label学习,可以降低学习的难度,这样学生网络甚至可以比教师网络更强。

Deep Mutual Learning则是让两个小的学生网络同时学习,对于每个单独的网络来说,会有针对hard label的分类损失函数,还有模仿另外的学生网络的损失函数,用于对齐学生网络的类别后验。

image

这种方式一般会产生这样的疑问,两个随机初始化的学生网络最初阶段性能都很差的情况,这样相互模仿可能会导致性能更差,或者性能停滞不前(the blind lead the blind)。

文章中这样进行解释:

image

DML具有的特点是:

DML中使用到了KL Divergence衡量两者之间的差距:

D_{K L}\left(\boldsymbol{p}_{2} \| \boldsymbol{p}_{1}\right)=\sum_{i=1}^{N} \sum_{m=1}^{M} p_{2}^{m}\left(\boldsymbol{x}_{i}\right) \log \frac{p_{2}^{m}\left(\boldsymbol{x}_{i}\right)}{p_{1}^{m}\left(\boldsymbol{x}_{i}\right)}

P1和P2代表两者的逻辑层输出,那么对于每个网络来说,他们需要学习的损失函数为:

\begin{aligned} &L_{\Theta_{1}}=L_{C_{1}}+D_{K L}\left(\boldsymbol{p}_{2} \| \boldsymbol{p}_{1}\right) \\ &L_{\Theta_{2}}=L_{C_{2}}+D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right) \end{aligned}

其中L_{C_{1}},L_{C_{2}}代表传统的分类损失函数,比如交叉熵损失函数。

可以发现KL divergence是非对称的,那么对两个网络来说,学习到的会有所不同,所以可以使用堆成的Jensen-Shannon Divergence Loss作为替代:

\frac{1}{2}\left(D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right)+D_{K L}\left(\boldsymbol{p}_{1} \| \boldsymbol{p}_{2}\right)\right)

更新过程的伪代码:

image

更多的互学习对象

给定K个互学习网络,\Theta_{1}, \Theta_{2}, \ldots, \Theta_{K}(K \geq 2), 那么目标函数变为:

L_{\Theta_{k}}=L_{C_{k}}+\frac{1}{K-1} \sum_{l=1, l \neq k}^{K} D_{K L}\left(\boldsymbol{p}_{l} \| \boldsymbol{p}_{k}\right)

将模仿信息变为其他互学习网络的KL divergence的均值。

扩展到半监督学习

在训练半监督的时候,我们对于有标签数据只使用交叉熵损失函数,对于所有训练数据(包括有标签和无标签)的计算KL Divergence 损失。

这是因为KL Divergence loss的计算天然的不需要真实标签,因此有助于半监督的学习。

实验结果

几个网络的参数情况:

image

在CIFAR10和CIFAR100上训练效果

image

在Reid数据集Market-1501上也进行了测试:

image

发现互学习目标越多,性能呈上升趋势:

image

结论

本文提出了一种简单而普遍适用的方法来提高深度神经网络的性能,方法是在一个队列中通过对等和相互蒸馏进行训练。

通过这种方法,可以获得紧凑的网络,其性能优于那些从强大但静态的教师中提炼出来的网络。
DML的一个应用是获得紧凑、快速和有效的网络。文章还表明,这种方法也有希望提高大型强大网络的性能,并且以这种方式训练的网络队列可以作为一个集成来进一步提高性能。

参考

https://github.com/AberHu/Knowledge-Distillation-Zoo

https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

上一篇下一篇

猜你喜欢

热点阅读