论文阅读_知识蒸馏_Distilling_BERT
英文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
中文题目:从BERT中蒸馏指定任务知识到简单网络
论文地址:https://arxiv.org/pdf/1903.12136.pdf
领域:自然语言,深度学习
发表时间:2019
作者:Raphael Tang, 滑铁卢大学
被引量:226
代码和数据:https://github.com/qiangsiwei/bert_distill
阅读时间:2022.09.11
读后感
第一次对大型自然语言模型的蒸馏:将BERT模型蒸馏成BiLSTM模型。
介绍
在自然语言处理方面,随着BERT,GPT等大规模预训练模型的发展,浅层的深度学习模型似乎已经过时了。但由于资源的限制,又需要使用小而快的模型。
文章的动机是讨论:浅层模型是否真的不具备对文本的表示能力?并展示了针对于具体的任务,将BERT蒸馏成单层BiLSTM模型的方法和效果。也通过大模型(起初训练的复杂的模型,后称Teacher/T)和小模型(蒸馏后的模型,后称Student/S)完全不同的模型结构展示了蒸馏与模型结构无关。另外,之前蒸馏模型主要应用于图片建模 ,论文讨论了它在自然语言领域的使用方法。
方法
核心方法包含两部分:增加了logit回归目标;重建蒸馏训练数据集使训练更为有效。
模型结构
将BERT作为教师模型,使用单层的 BiLSTM 作为学习模型的非线性分类器,针对每一种下游任务使用不同模型。如图-1是对单句分类任务设计的学生模型。
图-2展示了用于预测句子匹配度的模型,它们的编码层共享同一BiLSTM模型。
为了更好地对比效果,在学生模型中,未使用注意力归一化等更多技巧。
蒸馏目标
学习模型的目标是在所有数据上,模拟老师模型的行为。除了最终的标签,老师模型预测出的概率也很重要 。比如在情绪分类问题中,一些实例有很强的正面情绪,有一些情绪可能比较中性,所以除了是否,也需要预测程度。
一般预测标签的方法是:
文中使用了logit的优化方法,构造了蒸馏目标:用MSE来惩罚师生模型间的差异:
其中z(B)指的是老师模型BERT,z(S)指学生模型,在初步实验中,MSE比软目标效果更好。
在实际训练时,也使用了传统的交叉熵(对真正目标的预测)和蒸馏损失相结合的方式,最终损失函数如下:
当使用有标签数据训练时,t是实例的标签;使用无标签数据训练时,使用老师模型打标签。
蒸馏的数据增强
在蒸馏过程中,使用小的数据集不足以让老师模型展示出其所有知识,因此,使用了无标签数据扩充训练数据集,用老师模型对其打标签。
增强NLP数据比增强图像数据难度大,没办法使用扭曲等方法,做出的句子可能不够流畅。文中提出了几种数据增强方法:
- 遮蔽:使用类似BERT的方法,这种方法能反应句中每个词对标签的贡献。
- 基于词性的词替换:在词袋里找同一词性的词作替换,以保持原始数据的分布。
- n-gram采样:根据概率,随机采样n个连续的词,它是遮蔽方法的增强版。
实验
使用的是BERT_LARGE作为老师模型,针对特定任务精调,预测时获取预测的logit值,学生模型使用300维的word2vec作为词嵌入。主实验效果如表-1所示:
可以看到同样是使用BiLSTM方法,文中方法相较于其它方法有显著提升。
从表-2可以看到预测速度也有很大提升: