因果推断推荐系统工具箱 - CASR(三)
文章名称
【SIGIR-2021】【Beijing Key Laboratory of Big Data Management and Analysis Methods】Counterfactual Data-Augmented Sequential Recommendation
核心要点
文章旨在利用反事实数据生成的方法,解决Session-Based推荐场景下数据稀疏限制模型性能的问题。作者提出CASR框架,包含3种增广序列生成的sampler model,以及传统的序列推荐模型anchor model用于生成最终的推荐列表。Sampler model是整个框架的关键,其中一种是简单的随机替换模型,另外两种是面向数据和面向模型的序列生成模型。面向数据的模型会生成处在决策边界附近的反事实序列,而面向模型的方法则会以最大化提供给anchor model的信息为目标生成反事实序列。
前两节节介绍了,文章需要解决的额问题背景,方法框架,3种sampler model以及anchor model。本节总结一下方法的实现流程并给出与sampler model相关的误差分析。
方法细节
问题引入
上一节讲解了2种learning-based sampler model以及如何与anchor model进行结合。那么整体训练过程是什么样的呢?
此外,当前交互物品(也就是标签物品)是通过优化如下图所示的公式得到的,sampler model 可能并没有达到比较好的效果,这样产生的反事实样本不可避免的存在噪声,会影响anchor model 的训练。那么这个噪声会多大程度上影响模型效果呢?
current item decided by sampler model具体做法
Traning Pipeline
首先,我们总结一下,如何训练sampler model和anchor model,
- 观测数据阶段。在观测数据上,同时预训练sampler model和anchor model。
- 反事实采样阶段。首先,分别采用3种sampler model(或其中几种)生成反事实样本。在确定好需要生成的反事实样本数量和需要替换的序列index 之后,利用不同的sampler寻找替换样本以及当前交互样本。例如,
- 简单的sampler,直接随机寻找替换样本,此时,随后利用上述优化公式2,得到。
- 面向数据的sampler,优化如下图所示的公式,得到,再投影到,随后利用上述优化公式2,得到。值得注意的是,如前所述,需要确认。
Differentiable Data-oriented counterfactual sequence Loss- 面向模型的sampler,优化如下图所示的公式,得到,再投影到,随后利用上述优化公式2,得到。
Differentiable Model-oriented counterfactual sequence Loss- 反事实训练阶段。得到反事实样本后,将其与观测样本结合训练anchor model 。作者表示,面向模型的训练方法隐含了对抗训练的思想在里边。因为,现寻找反事实样本让模型的损失增大,随后在训练模型让损失减小,是一个minmax game。
整个训练的pipeline伪代码参见代码部分。
Theoretical Analysis
如前所述,通过sampler model 生成的反事实样本中的当前交互物品可能存在噪声。作者利用PAC理论[1]分析了噪声的影响,回答”给定采样器模型的噪声水平,需要多少个样本才能获得足够好的性能?“。
定义表示是在反事实样本序列下,会真实出现的当前交互样本的标签。正确估计的概率为。那么,如果则可以生成完全合理的反事实样本,如果如果则类似随机采样。
作者可以得到如下图所示的结论,假设anchor model 的预测误差是,那么观测和预测之前不相符的原因可以分为两个部分,
- 观测(生成的反事实是对的),预测错了,概率为;
- 观测(生成的反事实是错的),预测对了,概率为;
因此,整体的不相符概率为,作者利用反正法进行证明。
Theorem1证明的思路大致是,如果的预测误差是,那么两个条件必须同时满足,
- 经验风险最小化得到的模型的预测误差大于
如果这两个条件不能同时满足,则Theorem1得证。
利用uniform convergence properties[1],得到如下引理论,
Lemma 2.1有了Lemma2.1,我们再来看上述两个条件。
- 条件一
- 因为的预测误差是,所以期望损失。
- 又因为,所以应该小于。
- 但根据Lemma2.1,如果,所以。
- 条件二
- 如果模型的期望损失且经验风险损失,那么应该小于。
- 同样根据Lemma2.1,如果,所以。
因此,上述两个条件成立的概率都小于。因此,的预测误差小于的概率大于,Theorem1得证。
可以看出只要反事实样本量,可以保证模型的预测误差是在可接受范围内(也就是噪声越大需要的反事实样本越大,才能掩盖住噪声...)。
从理论分析中看,random sampler model 需要接近无穷多的样本。同时,收到分析的启发,作者引入超参数,是的只有公式2优化的目标大于的反事实样本被生成出来(sampler model足够自信的反事实样本)。但是需要平衡反事实样本少带来的偏差和噪声带来的误差(详情参见论文)。
代码实现
文章算法pipeline的伪代码如下图所示。
Learning Algorithm of CASR心得体会
Adversial Training
作者提到Model-oriented sampler model是类似对抗训练的思路。个人理解,这算是所有反事实样本生成训练的一个核心卖点,因为本身反事实样本是在做增广,对抗训练也在做增广。反事实是找和真实样本比较接近的对比样本,对抗则是专门供给模型的弱点(感觉,对比学习是利用数据或领域固有的特点进行对抗)。所以,是在用不同的方法榨取数据信息。
文章引用
[1] Shai Shalev-Shwartz and Shai Ben-David. 2014. Understanding machine learning: From theory to algorithms. Cambridge university press.