因果推断推荐系统工具箱 - CASR(二)
文章名称
【SIGIR-2021】【Beijing Key Laboratory of Big Data Management and Analysis Methods】Counterfactual Data-Augmented Sequential Recommendation
核心要点
文章旨在利用反事实数据生成的方法,解决Session-Based推荐场景下数据稀疏限制模型性能的问题。作者提出CASR框架,包含3种增广序列生成的sampler model,以及传统的序列推荐模型anchor model用于生成最终的推荐列表。Sampler model是整个框架的关键,其中一种是简单的随机替换模型,另外两种是面向数据和面向模型的序列生成模型。面向数据的模型会生成处在决策边界附近的反事实序列,而面向模型的方法则会以最大化提供给anchor model的信息为目标生成反事实序列。
上一节介绍了,文章需要解决的额问题背景,方法框架以及简单的sampler model。本节继续介绍更合理的sampler model以及anchor model。
方法细节
问题引入
上一节提到,基于启发式的反事实序列生成方法(优化目标如下图所示)太简单了,并且引入过多随机性,造成效果不理想。
Heuristic Sampler Loss
回顾一下,序列化推荐的形式化定义,
- 用户集合为
,物品集合为
;
- 对用户
的推荐结果为
,其反事实推荐结果为
,表示修改用户行为后,可以被替换的物品;
- 用户
的历史行为集合记作
,其中
表示用户在
时刻交互的物品;
- 序列推荐模型的目标是基于
(所有训练数据),准确预测每一个用户的下一个物品
(以及用户未来的偏好);
- 序列推荐模型可以用
表示。通常通过优化如下目标,来得到模型参数。在大规模物品集合的场景下,一般采用负采样或sample softmax的方法减少计算复杂度,其中,
为负采样的行为序列,若
出现在用户历史行为中,则
为1,否则为0。
crossentropy loss
我们期望,能够生成信息丰富,但更具有指向性的反事实是序列,来训练推荐模型。
具体做法
为了解决上述问题,作者提出两种可学习的序列生成方法,面向数据的和面向模型的方法,两种方法的流程架构如下图所示。
Learning-Based Sampler
Data-oriented counterfactual sequence learning
在具有标签的数据集上,可以利用标签把训练数据划分为多个部分,各部分(间)的边界称之为decision boundaries。[1, 2]的研究表明,decision boundaries附近的样本通常在揭示底层数据模式方面具有辨别力,基于它们训练能够提高模型性能。
基于这个思路,作者通过最小改动生成反事实行为序列来恰好改变模型的预测结果(看过之前反事实解释文章的同学应该知道,这就是所谓的counterfactual explanation,详情可以参见因果可解释推荐系统工具箱 - CountER(一)和因果可解释推荐系统工具箱 - ACCENT(一)),并利用这些反事实序列训练模型。
不同于前面讲过的反事实解释方法,作者是通过在隐向量空间,对特定目标进行优化,来生成反事实序列的,具体优化目标如下图所示。其中是物品
在向量空间中的表示(这里作者没有明确说明,我没理解错的话,是Sampler Model空间中的表示,当然也可以让Sampler Model和Anchor Model共享底层表示空间,只是作者没有详细区分)。
Raw Data-oriented counterfactual sequence Loss
同上一节Heuristic Sampler方法一样,该方法的目标仍然是寻找一个物品代替用户历史行为序列中,特定序号
的物品
。只是,此时的选择方法并非完全随机,而是通过优化上述目标得到的。同时,需要满足改动后,目标物品与现在(真实序列中)的目标物品不同的约束。实际上可以理解为,作者对用户的真实行为序列做了两个地方的改动,1)改变
成为
,2)由于
,导致
改变。而,这个新的
被记作
,是通过优化约束中的
得到的(也就是问题引入中提到的公式2)。
优化目标,保证了替换物品与原始物品
足够相似,你一定懂了,这也保证这种改动是最小改动。
作者提到,的候选物品
可以是利用先验知识选择的物品自己,也可以就是物品的全集
(当然也取决于物品集合的大小)。此外,如果
和
的差距太小,不能够改变
,那说明生成的反事实序列在决策边界上。
到这里还没完。上述的优化目标是不可导的。
因此,作者重写了优化目标,在隐空间中先寻找到近似的向量,再投影到具体物品的向量上。具体地说,作者提出了虚拟的的概念,其隐向量表示
。
是连续可变的可学习参数,通过优化如下目标,可以利用可导的方法,学习到
,然后再投影到真实物品上。
Differentiable Data-oriented counterfactual sequence Loss
上述优化目标的
- 第一项,保证
与原始物品
足够相似。
- 第二项,告诉模型,当前真实的
物品不是我们期望的(通过给模型增加
出现概率的惩罚项实现),类似于原始优化问题的约束项。
是超参数。
得到可以利用如下优化方法,把
投影到真实的物品上。
Projection to Real Item
Model-oriented counterfactual sequence learning
除了从数据决策边界的角度生成反事实序列,作者还借鉴了[3,4]的思想,提出了面向模型的反事实行为序列生成方法。基本思想史,寻找能够为模型提供较大损失的样本,因为此类样本意味着模型没有学习好(想想Boosting的做法),能够提供更多信息,提升模型性能。
因此,作者通过最大化anchor model的损失来生成反事实序列。同样用替换
,只不过优化的目标如下图所示。可以看到,这个优化目标中引入了anchor model
。第一个约束,保证生成的样本是由sampler model给出的,第二个约束,保证在sampler model的向量表示空间中
和
的表示足够接近,接近程度通过
控制。
同样上述优化目标不可导,仍然利用虚拟的,来寻找适合的替换物品。作者定义
表示用户(在sampler和anchor模型两个空间,如果
就是anchor模型空间)的真实序列或者是替换后的行为序列。因此,可以通过优化如下图所示的目标来生成反事实序列。
Differentiable Model-oriented counterfactual sequence Loss
作者利用softmax来产生候选的。其中
用来控制是soft的选择还是hard(在softmax中常用的温度函数,
趋于0,平均选择各个物品,
趋于无穷,类似于
)。
上述优化目标的
- 第一项,是在估计sampler model的各种
选择下,anchor模型的平均损失(因为一种行为序列修改修改,sampler模型可能给出不同的
,按概率平均加权求和就得到了anchor模型损失的期望)。
- 第二项,保证
与原始物品
足够相似。
是超参数。
值得注意的是,作者表示
- Data-oriente是自底向上的,模型无关的;
- Model-oriented是自顶向下的,和业务结合的,依赖于模型。
两者有不同的角色,可以在不同场景中应用。
本节介绍了两种learning-based的sampler model,anchor model其实就是普通的序列推荐模型。下节介绍如何学习模型参数以及作者进行的理论分析。
心得体会
反事实解释
从文章的内容可以看出,反事实解释除了可以用来对模型的预测结果进行解释,帮助客户理解模型推荐的理由,增加对推荐系统的信赖度。同时,可以帮助开发人员进行debug。此外,还可以作为数据增广的途径和方式。
脑洞更大一点,反事实解释可以在决策边界内外游走,如果具有较好的可控制性,我们可以进行对抗,增加鲁邦。同时,控制我们的模型和策略。
大损失样本
个人感觉,在利用大损失样本之前,首先需要进行数据去噪。一般噪声样本会有较大的损失,不能够有效指导模型训练。
替换数量
从全文看,作者只在真实行为序列中替换了一个物品,即便替换两个物品的平均改动可能更小,也仍然寻找一个。这是方法决定的,与其他可以替换多个物品(物品集合)的counterfactual explanation的方法不同。
文章引用
[1] EhsanAbbasnejad,DamienTeney,AminParvaneh,JavenShi,andAntonvanden Hengel. 2020. Counterfactual vision and language learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 10044–10054.
[2] YashGoyal,ZiyanWu,JanErnst,DhruvBatra,DeviParikh,andStefanLee.2019.Counterfactual visual explanations. arXiv preprint arXiv:1904.07451 (2019).
[3] Tsu-Jui Fu, Xin Eric Wang, Matthew F Peterson, Scott T Grafton, Miguel P Eckstein, and William Yang Wang. 2020. Counterfactual Vision-and-Language Navigation via Adversarial Path Sampler. In European Conference on Computer Vision. Springer, 71–86.
[4] Hongchang Gao and Heng Huang. 2018. Self-paced network embedding. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 1406–1415.
crossentropy loss