读书人工智能

GCA:基于自适应数据增强的图对比学习

2022-05-11  本文已影响0人  酷酷的群

论文标题:Graph Contrastive Learning with Adaptive Augmentation
论文链接:https://arxiv.org/abs/2010.14945
论文来源:WWW 2021

一、概述

图对比学习中的数据增强在近来的方法中被证明是一个关键的部分,然而对于图数据增强的方法的研究却是不充分的。对于图像和文本来说,数据增强有很多种方式,然而对于图数据来说,数据增强是不容易的,这是由图数据的非欧几里得特性引起的。本文认为过去的图数据增强方法有两个缺点:
①简单的数据增强,比如DGI中的特征打乱,对于生成节点多样化的邻域(也就是上下文)是不充分的,尤其是节点特征较为稀疏时,会导致对比目标函数的优化是困难的;
②在执行数据增强时忽略了不同节点和不同边的影响。

数据增强应该保留原始数据最本质的特征,举例来说,对一张猫的图片进行数据增强(比如对其进行翻转),应该使得增强后的图仍然是一只猫,而不能使其丢失其作为猫的本质特征,这样模型才可以通过对比来学习到猫的特征。同样的,如果在对图通过随机删除边的方式进行数据增强时,某些重要的边被删除掉就会影响最终学习到的embedding的质量,也就是说图的不同节点和不同边在数据增强时产生的影响是不同的,在数据增强时应该尽可能的保留重要的边以及重要的节点特征,而一些图数据增强采用随机处理的方式,这样势必造成性能的损伤。

对比学习得到的表示应该对通过数据增强引入的破坏具备一定的不变性,因而数据增强策略应自适应输入的图,以反映其内在模式。同样的以删除边的数据增强方式为例,我们应该给不重要的边以大的移除概率,给重要的边以小的移除概率。然后,该方案能够引导模型忽略不重要边上引入的噪声,从而学习输入图下的重要模式。

本文提出了Graph Contrastive learning with Adaptive augmentation(GCA)框架来利用自适应的数据增强策略进行图的节点表示学习,整体框架图如下:

框架

本文采用的自适应图数据增强的主要思想是给不重要的边以更大的移除概率,给不重要的节点特征维度以更大的mask概率。

二、方法

  1. 定义

使用\mathcal{G}=(\mathcal{V},\mathcal{E})代表一个图,\mathcal{V}=\left \{v_{1},v_{2},\cdots ,v_{N}\right \},\mathcal{E}\in \mathcal{V}\times \mathcal{V}代表节点集合以及边的集合。使用X\in \mathbb{R}^{N\times F}代表节点特征矩阵,使用A\in \left \{0,1\right \}^{N\times N}代表邻接矩阵,x_{i}\in \mathbb{R}^{F}是节点v_i的特征,当(v_{i},v_{j})\in \mathcal{E}A_{ij}=1。我们的目标是学习一个GNN encoderH=f(X,A)\in \mathbb{R}^{N\times F^{'}},提供低维节点表示,也就是F^{'}\ll F,使用h_{i}表示学习到的v_i的表示,这些表示将被用于下游任务。本文实验中采用的encoder为两层GCN。

  1. 对比学习框架

按照GCA的框架,每次迭代时采样两个随机增强函数t\sim \mathcal{T}t^{'}\sim \mathcal{T}\mathcal{T}是所有可能的增强函数的集合。然后获得图的两个视图\tilde{\mathcal{G}}_{1}=t(\mathcal{G})以及\tilde{\mathcal{G}}_{2}=t^{'}(\mathcal{G}),学习到的节点表示为U=f(\tilde{X}_{1},\tilde{A}_{1})以及V=f(\tilde{X}_{2},\tilde{A}_{2}),这里的\tilde{X}_{*},\tilde{A}_{*}是视图的特征矩阵和邻接矩阵。

接着使用得到的节点表示来进行对比学习的过程。对于节点v_i关于两个个视图的表示分别为\textbf{u}_{i}\textbf{v}_{i},正样本为\textbf{u}_i\textbf{v}_i的组合,负样本为\textbf{u}_{i}\textbf{v}_{i}分别与两个视图中其他节点的表示的组合。对于正样本对(\textbf{u}_{i},\textbf{v}_{i}),采用InfoNCE损失:

\ell(\textbf{u}_{i},\textbf{v}_{i})=\frac{e^{\theta (\textbf{u}_{i},\textbf{v}_{i})/\tau }}{\underset{\mathrm{positive\: pair}}{\underbrace{e^{\theta (\textbf{u}_{i},\textbf{v}_{i})/\tau }}}+\underset{\mathrm{inter-view\: negative\: pairs}}{\underbrace{\sum _{k\neq i}e^{\theta (\textbf{u}_{i},\textbf{v}_{k})/\tau }}}+\underset{\mathrm{intra-view\: negative\: pairs}}{\underbrace{\sum _{k\neq i}e^{\theta (\textbf{u}_{i},\textbf{u}_{k})/\tau }}}}

这里的\tau是一个温度超参数。本文定义discriminator为\theta (\textbf{u},\textbf{v})=s(g(\textbf{u}),g(\textbf{v})),这里s(\cdot ,\cdot )是余弦相似度,g(\cdot )是一个非线性变换(采用两层MLP),用来增强discriminator的表达能力。最终需要最大化的目标函数为:

\mathcal{J}=\frac{1}{2N}\sum_{i=1}^{N}[\ell(\textbf{u}_{i},\textbf{v}_{i})+\ell(\textbf{v}_{i},\textbf{u}_{i})]

整个模型的算法如下:

算法
  1. 自适应图数据增强

本文研究的重点在于自适应的图数据增强方法,希望数据增强在扰乱不重要的连接和特征时能够保持重要的结构和属性不变。本文采用随机删除边和mask节点属性的方式来进行数据增强,但是删除和mask的概率应该向不重要的边或特征倾斜,也就是对于不重要的边或特征删除或mask的概率要大,对于重要的要小,这是本文数据增强方法设计的基本思想。

对于图的拓扑结构,本文考虑采用随机移除边的方式来进行数据增强,具体的,就是从原始边集合\mathcal{E}中采样一个子集\tilde{\mathcal{E}},采样过程依照以下概率:

P\left \{(u,v)\in \tilde{\mathcal{E}}\right \}=1-p_{uv}^{e}

这里的p_{uv}^{e}是移除边(u,v)的概率,\tilde{\mathcal{E}}就是生成的视图的边集合。p_{uv}^{e}应该能够反映边的重要性,对于重要的边p_{uv}^{e}应该要小于不重要的边。

在网络科学中,节点中心性(node centrality)是度量节点影响力时广泛使用的度量。我们通过边的两个节点的中心性来定义边(u,v)的中心性w_{uv}^{e}。给定一个节点中心性度量\varphi _{c}(\cdot ):\mathcal{V}\rightarrow \mathbb{R}^{+},边的中心性定义为两个节点中心性的平均:

w_{uv}^{e}=(\varphi _{c}(u)+\varphi _{c}(v))/2

在有向图上,我们简单地使用尾节点的中心性,也就是w_{uv}^{e}=\varphi _{c}(v),因为边的重要性通常由被指向的节点决定。接下来基于边的中心性来获得其移除的概率。首先,由于节点中心性(比如采用节点度时)的数值可能跨越多个数量级,因此设置s_{uv}^{e}=log\: w_{uv}^{e}来缓解连接密集的节点的影响。接着概率通过一个标准化的过程来获得:

p_{uv}^{e}=min\left (\frac{s_{max}^{e}-s_{uv}^{e}}{s_{max}^{e}-u_{s}^{e}}\cdot p_{e},p_{\tau }\right )

这里的p_{e}是一个超参数来控制移除边的总体概率,s_{max}^{e}u_{s}^{e}s_{uv}^{e}的最大值和平均值,并且p_{\tau }<1,是一个截断概率,用来防止过高的移除概率导致对图的过分破坏。

对于节点中心性的度量,采用以下三种:度中心性、特征向量中心性以及PageRank中心性。这三种度量是简单而高效的。

度中心性 节点的度本身可以作为中心性的度量。在有向图中采用节点的入度。尽管节点度是最简单的中心性度量之一,但它非常有效且具有启发性。这一度量背后的假设是重要的节点就是拥有许多连接的节点。

特征向量中心性 节点的特征向量中心性定义为对应于邻接矩阵的最大特征值的特征向量,具体的,有A\xi =\lambda \xi,这里的\lambda是邻接矩阵的最大特征值,则节点v_i的特征向量中心性就是\xi _i。特征向量中心性的基本思想是,一个节点的中心性是相邻节点中心性的函数。也就是说,与你连接的人越重要,你也就越重要。不同于度中心性,度中心性假设所有邻居对节点的重要性贡献相等,特征向量中心性也考虑了邻居节点的重要性。由于A\xi =\lambda \xi,那么特征向量中心性向量\xi就可以表示为\xi =\lambda ^{-1}A\xi,那么节点v_i的中心性就是:

EC(v_{i})=\xi _{i}=\lambda ^{-1}\sum _{j=1}^{N}a_{ij}\xi _{j}

上面的式子表明节点v_i的中心性相当于对其邻居节点的中心性做了平均。当一个节点与很多节点相连或者与高影响力的节点相连时会有比较高的节点中心性。在有向图上,我们使用右特征向量来计算中心性,它对应节点的入边。注意,由于只需要最大特征值的特征向量,计算特征向量中心性的计算负担是可以忽略的。

PageRank中心性 PageRank中心性定义为PageRank算法计算得到的PageRank权重。该算法将影响沿有向边传播,将聚集的影响最大的节点视为重要节点。具体的,中心性数值定义为:

\sigma =\alpha AD^{-1}\sigma +1

这里\sigma \in \mathbb{R}^{N}是PageRank中心性得分向量,\alpha是一个阻尼因子,设置为\alpha =0.85。对于无向图,我们通过将一条无向边转换为两条有向边来将其转换为有向图。

下图展示了3种不同的中心性度量的应用:

中心性

我们通过随机mask一些维度来为节点特征添加噪声。首先我们采样一个随机向量\tilde{m}\in \left \{0,1\right \}^{F},每个维度由伯努利分布独立采样得到,也就是\tilde{m}_{i}\sim Bern(1-p_{i}^{f}),\forall i,然后生成的节点特征\tilde{X}为:

\tilde{X}=[x_{1}\circ \tilde{m};x_{2}\circ \tilde{m};\cdots ;x_{N}\circ \tilde{m}]^{T}

这里的[\cdot ,\cdot ]是拼接操作,\circ是哈达玛积。

类似拓扑结构的数据增强,节点属性维度的mask概率p_{i}^{f}应该反映第i个维度的重要性,我们假设在重要的节点中频繁出现的维度是重要的,以此来定义特征维度的权重。对于稀疏one-hot节点特征,比如x_{ui}\in \left \{0,1\right \},维度i的权重计算为:

w_{i}^{f}=\sum _{u\in \mathcal{V}}x_{ui}\cdot \varphi _{c}(u)

这里的\varphi _{c}(\cdot )是一种节点中心性度量。上面式子中第一项x_{ui}\in \left \{0,1\right \}表示维度i是否在节点u中出现,第二项衡量每次出现的节点重要性。举个直观的例子,考虑一个引用网络,其中每个特征维度对应一个关键字,在一篇影响力很大的论文中频繁出现的关键词应该被认为是信息丰富和重要的。

对于稠密连续的节点特征,采用以下方式:

w_{i}^{f}=\sum _{u\in \mathcal{V}}|x_{ui}|\cdot \varphi _{c}(u)

类似的,以标准化的方式获得概率:

p_{i}^{f}=min\left (\frac{s_{max}^{f}-s_{i}^{f}}{s_{max}^{f}-u_{s}^{f}}\cdot p_{f},p_{\tau }\right )

这里的s_{i}^{f}=log\: w_{i}^{f}s_{max}^{f}u_{s}^{f}s_{i}^{f}的最大值和平均值,p_f是一个控制总体概率的超参数。

在GCA中联合执行拓扑和节点的数据增强。对于两个视图来说,p_ep_f是不一样的,用p_{e,1},p_{f,1}p_{e,2},p_{f,2}来表示。另外采用三种不同的节点中心性度量的GCA分别记作GCA-DE, GCA-EV和GCA-PR。GCA是为学习节点表示而设计的,并未涉及图表示的学习。注意,所有的中心性和权重度量都只依赖于原始图的拓扑和节点属性。因此,它们只需要计算一次,不会带来太大的计算负担。

三、实验

  1. 数据集

本文实验采用以下数据集:

数据集
  1. 实验

下表为上面数据集上的节点分类任务性能指标:

实验
  1. 消融实验

下面的消融实验探究了两种数据增强方式的影响:

消融实验

下图展示了不同的概率对性能的影响:

消融实验
上一篇下一篇

猜你喜欢

热点阅读