知识蒸馏在推荐系统中的应用

2020-05-30  本文已影响0人  irving不会机器学习

一、知识蒸馏典型方法

做什么

解决复杂模型上线,模型响应速度太慢,当流量大的时候撑不住问题

如何做

可以将它学到的暗知识 ( Dark Knowledge ) 迁移给学习能力相对弱的 Student 模型,以此来增强 Student 模型的泛化能力,最后上线Student模型

两个技术发展主线

Logits 方法

logits是什么?

softmax层的输入,汇总了网络内部各种信息后,得出的属于各个类别的汇总分值 zi,就是 Logits。

实现

假设我们有一个 Teacher 网络,一个 Student 网络,输入同一个数据给这两个网络,Teacher 会得到一个 Logits 向量,代表 Teacher 认为输入数据属于各个类别的可能性;Student 也有一个 Logits 向量,代表了 Student 认为输入数据属于各个类别的可能性。最简单也是最早的知识蒸馏工作,就是让 Student 的 Logits 去拟合 Teacher 的 Logits,即 Student 的损失函数为:

Student 的损失函数由两项组成,一个子项是 Ground Truth,就是在训练集上的标准交叉熵损失,让 Student 去拟合训练数据,另外一个是蒸馏损失,让 Student 去拟合 Teacher 的 Logits:

H 是交叉熵损失函数,f(x) 是 Student 模型的映射函数,y 是 Ground Truth Label,zt 是 Teacher 的 Logits,zs 是 Student 的 Logits, ST() 是 Softmax Temperature 函数,λ 用于调节蒸馏 Loss 的影响程度。

特征蒸馏方法

强迫 Student 某些中间层的网络响应,要去逼近 Teacher 对应的中间层的网络响应。

二、知识蒸馏在推荐系统中的三个应用场景

精排

希望找到一个模型,这个模型既有较好的推荐质量,又能有快速推理能力。

我们在离线训练的时候,可以训练一个复杂精排模型作为 Teacher,一个结构较简单的 DNN 排序模型作为 Student。因为 Student 结构简单,所以模型表达能力弱,于是,我们可以在 Student 训练的时候,除了采用常规的 Ground Truth 训练数据外,Teacher 也辅助 Student 的训练,将 Teacher 复杂模型学到的一些知识迁移给 Student,增强其模型表达能力,以此加强其推荐效果。在模型上线服务的时候,并不用那个大 Teacher,而是使用小的 Student 作为线上服务精排模型,进行在线推理。因为 Student 结构较为简单,所以在线推理速度会大大快于复杂模型;而因为 Teacher 将一些知识迁移给 Student,所以经过知识蒸馏的 Student 推荐质量也比单纯 Student 自己训练质量要高。

模型召回以及粗排采用知识蒸馏

1. 用复杂的精排模型作为 Teacher,召回或粗排模型作为小的 Student,比如 FM 或者双塔 DNN 模型等,Student 模型模拟精排环节的排序结果,以此来指导召回或粗排 Student 模型的优化过程

2. 通过 Student 模型模拟精排模型的排序结果,可以使得前置两个环节的优化目标和推荐任务的最终优化目标保持一致,在推荐系统中,前两个环节优化目标保持和精排优化目标一致,其实是很重要的,但是这点往往在实做中容易被忽略

三、实现方法

精排环节蒸馏方法

目前推荐领域里,在精排环节采用知识蒸馏,主要采用 Teacher 和 Student 联合训练 ( Joint Learning ) 的方法,而目的是通过复杂 Teacher 来辅导小 Student 模型的训练,将 Student 推上线,增快模型响应速度。

对于 Student 网络来说,损失函数由两个部分构成,一个子项是交叉熵,这是常规的损失函数,它促使 Student 网络去拟合训练数据;另外一个子项则迫使 Student 输出的 Logits 去拟合 Teacher 输出的 Logits,所谓蒸馏,就体现在这个损失函数子项,通过这种手段让 Teacher 网络增强 Student 网络的模型泛化能力

这个模型是阿里妈妈在论文 "Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net" 中提出的,其要点有三:其一两个模型同时训练;其二,Teacher 和 Student 共享特征 Embedding;其三,通过 Logits 进行知识蒸馏。对细节部分感兴趣的同学可以参考原始文献。

四、召回/粗排环节蒸馏方法

召回蒸馏的两阶段方法

在专门的知识蒸馏研究领域里,蒸馏过程大都采取两阶段的模式,就是说第一阶段先训练好 Teacher 模型,第二阶段是训练 Student 的过程,在 Student 训练过程中会使用训练好 Teacher 提供额外的 Logits 等信息,辅助 Student 的训练。


Logits 方案

在召回或者精排采用知识蒸馏,此时,精排模型其实身兼二职:主业是做好线上的精准排序,副业是顺手可以教导一下召回及粗排模型。所以,其实我们为了让 Teacher 能够教导 Student,在训练 Student 的时候,并不需要专门训练一遍 Teacher 精排模型,因为它就在线上跑着呢。

Without-Logits 方案

另外一类方法可以进一步减少 Student 对 Teacher 的依赖,或适用于无法得到合理 Logits 信息的场合,即 Student 完全不参考 Logits 信息,但是精排作为 Teacher,怎么教导 Student 呢?别忘了,精排模型的输出结果是有序的,这里面也蕴含了 Teacher 的潜在知识,我们可以利用这个数据。也就是说,我们可以让 Student 模型完全拟合精排模型的排序结果,以此学习精排的排序偏好。我们知道,对于每次用户请求,推荐系统经过几个环节,通过精排输出 Top K 的 Item 作为推荐结果,这个推荐结果是有序的,排在越靠前的结果,应该是精排系统认为用户越会点击的物品。

五、联合训练召回、粗排及精排模型的设想

如果我们打算把知识蒸馏这个事情在推荐领域做得更彻底一点,比如在模型召回、粗排以及精排三个环节都用上,那么其实可以设想一种"一带三"的模型联合训练方法。

如上图所示,我们可以设计一个很复杂但是效果很好的排序模型作为 Teacher,然后和召回、粗排、精排三个 Student 联合训练,精排 Student 可以使用 Logits 以及隐层特征响应等各种手段优化,追求效果好前提下的尽可能速度快,召回和粗排 Student 则追求在模型小的前提下追求效果尽可能好。因为排序 Teacher 比较复杂,所以能够提供尽可能好的模型效果,通过它来带动三个环节蒸馏模型的效果,而模型速度快则是蒸馏方法的题中应有之意。

六、问题、讨论

七、参考文献

https://zhuanlan.zhihu.com/p/143155437

https://mp.weixin.qq.com/s/4OE8DqIVjb6PKxzhyQnMqA

上一篇下一篇

猜你喜欢

热点阅读