推荐系统论文阅读(五十三)-基于多任务模型的蒸馏召回模型
论文:
![](https://img.haomeiwen.com/i20689929/0e839622e4b92ed7.png)
论文题目:《Distillation based Multi-task Learning: A Candidate Generation Model for Improving Reading Duration》
论文地址:https://arxiv.org/pdf/2102.07142.pdf
中间因为工作原因断更了几个月,今天我们来重新阅读一篇关于蒸馏的论文吧,刚好笔者最近一直在做蒸馏相关的任务。
一 、背景
在qq看点中,用户可能会被一些标题党吸引住,这就造成了用户的点击率很高,但是浏览体验不是很好。通常要解决这种任务,就会同时对ctr任务和浏览时长任务进行建模,也就是用多任务学习的框架来建模。
多任务建模有两个挑战:
1)那些没有发生点击行为的样本,也就是0阅读时长的样本,这些样本没有被电击过,我们不能简单的当成负样本,有可能是用户压根就没有见过这些样本。如果简单把这些样本全部当成负样本,可能就会导致有偏的预估。
2)在召回阶段,一般很少进行多任务的学习,尤其是双塔模型,因为召回阶段通常是ann的方式生成候选集合,多个任务难以使用ann,多任务模型的如何利用同一份双塔来建模多个任务呢?这也是一个挑战。
为了解决以上两个问题,qq看点就提出了基于多任务模型框架的蒸馏模型,并运用到召回任务中,下面我们一起来看看细节吧。
二 、模型结构
先来看看模型的整体结构:
![](https://img.haomeiwen.com/i20689929/c76843b726583b96.png)
熟悉蒸馏的人都知道,如果要在粗排或者召回任务里去使用蒸馏,有几种方式,第一种是精排模型同召回一起训练,然后用精排的预估分数作为softlabel来蒸馏召回的logits,第二种是精排模型不跟召回一起训练,而是通过线上直接把打点的分数埋下来后在来蒸馏召回。
两种方法都可以,如果说想节约训练时间,可以直接同时训练两个模型,也可以用两阶段训练的方式。
我们现在回过来看这个模型结构图,先来看左边的网络,左边是精排的多任务模型,采用的结构是MMoE,关于MMoE,这里不在进行细说,感兴趣的可以自己去搜一下之前的文章。
精排模型建模的是ctr任务和cvr任务,这qq看点中的cvr任务被定义为浏览时长,也就是说浏览时长的建模这时候并不是一个回归任务了,而是分类任务,正样本选取浏览时间超过30s的样本,负样本选取浏览时长小于30s的样本。也就是说我们通过一个阈值,把用户的浏览时长任务转变成了分类任务,所以就有了模型结构中的Pctr和Pcvr这两个分数。
其实精排模型的思想跟ESMM类似,通过建模ctr和ctcvr任务来间接的优化cvr任务,看点这篇文章也是通过这个方法来建模的,关于ESMM和MMoE这里不在进行赘述,直接给出下面的公式:
ctr的向量
![](https://img.haomeiwen.com/i20689929/785c3cc7b7c3df1c.png)
cvr的向量:
![](https://img.haomeiwen.com/i20689929/1434ebf551dbcfa3.png)
ctr和cvr的分数:
![](https://img.haomeiwen.com/i20689929/a1825c91d13a1769.png)
ctcvr分数的计算:
![](https://img.haomeiwen.com/i20689929/6392a746b74c27bb.png)
模型的浏览时长任务loss:
![](https://img.haomeiwen.com/i20689929/ed467b3030eeea97.png)
模型的ctr任务的loss:
![](https://img.haomeiwen.com/i20689929/c2f9270bf1f4cd84.png)
模型的总loss:
![](https://img.haomeiwen.com/i20689929/ce9d57371e3b0e16.png)
好了,现在我们回到整个模型的框架里,左边是teacher网络,右边呢自然是student网络,student网络是一个双塔结构,前面我们已经说了,在召回模型里,很难来直接建模多个任务,一般都是ctr跟cvr两个双塔模型。如果,我们能够将精排模型的ctcvr分数直接对召回模型的logits直接进行蒸馏,这样就能让召回模型学习到了点击和浏览时长这两个任务来。
有一个重要的点,我们必须要明确,蒸馏必须是在同一个人任务里做蒸馏,也就是说我们的teacher的分数如果是ctcvr的话,那么召回的任务也只能是ctcvr,不能拿ctcvr的分数去蒸馏召回的ctr分数。所以可以看到召回分数也是ctcvr的预测分数,label是精排的logits,loss计算为:
![](https://img.haomeiwen.com/i20689929/3c7ac3f1df302e82.png)
ps:我觉得上面的损失好像少个log?
模型的总loss为:
![](https://img.haomeiwen.com/i20689929/1405c2aa75bdf171.png)
三、实验结果
![](https://img.haomeiwen.com/i20689929/ff3eaf992dea27e9.png)
![](https://img.haomeiwen.com/i20689929/8f6e23d20bfef6d5.png)