人工智能

GraphCL:基于数据增强的图对比学习

2022-04-30  本文已影响0人  酷酷的群

论文标题:Graph Contrastive Learning with Augmentations
论文链接:https://arxiv.org/abs/2010.13902
论文来源:NeurIPS 2020

一、概述

预训练在深度模型的训练过程中相当于作为一个正则化器以避免梯度消失或爆炸。而对于GNN而言,很少有关于(自监督)预训练相关技术的研究。究其原因可能是图数据集通常规模较小,而且GNN模型通常在结构上设计较浅以避免过平滑(over-smoothing)或者信息损失。对于图数据集来说,数据的标注(比如化学和生物领域内的分子标注)是困难的,预训练的方法可以缓解这一问题,因而本文着力研究图数据集上的自监督预训练方法。

图是具有不同性质的原始数据的抽象表示,数据可能来自各个领域(比如化学分子或者社交网络),在图的上下文中存在极为丰富的信息,因而不容易设计一个能够应用在各种下游任务的通用框架。一种比较朴素的预训练方式如GAE、GraphSAGE,主要通过重构节点邻接信息来实现,这种方式是非常有限的,因为其过分强调接近性,这并不总是有效的,有时候会忽略和损伤结构信息。因此,需要一个设计良好的预训练框架来捕获图结构数据中的高度异构信息。

本文提出的GraphCL作为一种图的预训练框架,采用对比学习的方法作为基础,并且实验采用了四种不同的图数据增强方式,同时应用互信息最大化的方法来进行训练。

二、方法

  1. 图神经网络

GNN通常遵循一个迭代的邻域聚合框架,使用\mathcal{G}=\left \{\mathcal{V},\mathcal{E}\right \}来表示一个无向图,同时X\in \mathbb{R}^{|\mathcal{V}|\times N}是特征矩阵,x_{n}=X[n,:]^{T}是节点v_{n}\in \mathcal{V}N维特征向量。考虑一个K层GNNf(\cdot ),其第k层的传播过程如下:

a_{n}^{(k)}=AGGREGATION^{(k)}(\left \{h_{n^{'}}^{(k-1)}:n^{'}\in N(n)\right \})\\ h_{n}^{(k)}=COMBINE^{(k)}(h_{n}^{(k-1)},a_{n}^{(k)})

这里h_{n}^{(k)}是节点v_n在第k层的embedding向量,并且有h_{n}^{(0)}=x_{n}N(n)是节点v_n的邻域节点集合。经过K层GNN传播后,图\mathcal{G}的embedding经由节点embedding向量通过一个READOUT函数生成,然后再通过一个MLP后用于图级的下游任务(分类或回归):

f(\mathcal{G})=READOUT(\left \{h_{n}^{(k)}:v_{n}\in V,k\in K\right \})\\ z_{\mathcal{G}}=MLP(f(\mathcal{G}))

  1. 数据增强

数据增强的目的是在不影响语义标签的情况下,通过一定的转换来创建新的现实合理数据。本文主要关注图级的数据增强。给定一个由M个图组成的数据集中的图\mathcal{G}\in \left \{\mathcal{G}_{m}:m\in M\right \},其增强图\hat{\mathcal{G}}\sim q(\hat{\mathcal{G}}|\mathcal{G}),这里的q(\cdot |\mathcal{G})是增强图的分布,代表着某种先验。对于图像分类而言,旋转和裁剪的应用对人们从旋转后的图像或其局部patch中获得相同的基于分类的语义知识进行了编码,具体的举个例子,一张猫的图片经过翻转数据增强仍然是只猫,而通过增强后的图片与原图的对比学习就可以使得神经网络模型学习到如何鉴别猫的关键特征。

在图上与图像上类似,对于图数据的数据增强应该保证增强后的图数据不应该丢失对于分类或者回归任务很关键的信息。不过对于图数据来说,由于其来自多个不同的领域,因此不容易像图像那样找到统一的数据增强方式。换句话说,对于不同类别的图数据集,某些数据增强可能比其他类型更需要。本文主要关注三类图数据集:生物化学分子(例如化合物、蛋白质)、社交网络以及图片super-pixel图。另外在实验中采用了四种不同的数据增强方式,分别对应不同的四种先验:
①Node dropping:对于给定图\mathcal{G},随机丢弃确定比例的节点以及其相关的连接。这背后的先验是认为丢弃部分节点并不会影响图\mathcal{G}的语义,每个节点的丢弃概率遵循相同的均匀分布(或者其他分布)。
②Edge perturbation:通过随机增删确定比例的边来扰乱图\mathcal{G}的连接。相关的先验是图\mathcal{G}的语义对边的连接有一定的鲁棒性。同样地采用一个相同的均匀分布来增删每条边。
③Attribute masking:促使模型通过上下文(未mask的属性)来预测被mask的节点属性。其先验是缺少部分节点属性不会对模型的预测造成太大影响。
④Subgraph:通过随机游走采样图\mathcal{G}的子图。其中的先验认为图\mathcal{G}的语义可以在它的局部结构中得到很大的保留。

数据增强

四种数据增强用到的比例默认设置为0.2。

  1. GraphCL

本文提出的graph contrastive learning(GraphCL)框架利用对比学习的方法来最大化图的两个不同视图之间的一致性以学习图的表示。下图展示了GraphCL的框架:

框架

主要包括以下4个部分:
①图数据增强模块:图\mathcal{G}通过图数据增强模块来获得两个不同的视图\hat{\mathcal{G}}_{i},\hat{\mathcal{G}}_{j},满足\hat{\mathcal{G}}_{i}\sim q_{i}(\cdot |\mathcal{G}),\hat{\mathcal{G}}_{j}\sim q_{j}(\cdot |\mathcal{G}),作为正样本对。对于不同领域的图数据集,如何有策略地选择数据增强至关重要。
②GNN encoder:一个GNN encoderf(\cdot ),用于获取视图\hat{\mathcal{G}}_{i},\hat{\mathcal{G}}_{j}的表示向量h_{i},h_{j},GraphCL对GNN的结构不做任何限制。
③非线性映射:一个非线性变换g(\cdot )将表示h_{i},h_{j}映射到另一个隐空间以获得z_{i},z_{j},在GraphCL中采用一个两层MLP。
④对比损失函数:对比损失函数\mathcal{L}(\cdot )被定义用来最大化z_{i},z_{j}之间的一致性,这里采用的是normalized temperature-scaled cross entropy loss (NT-Xent)。

具体的,在GNN预训练期间,一个mini-batch内由N个图,通过数据增强可以获得2N个增强图。对比的正样本对是一个图的两个视图,负样本对是一个图与其他图组成的样本对。首先定义余弦相似度sim(z_{n,i},z_{n,j})=z_{n,i}^{T}z_{n,j}/||z_{n,i}||\: ||z_{n,j}||,然后NT-Xent损失定义为:

\ell_{n}=-log\frac{exp(sim(z_{n,i},z_{n,j})/\tau )}{\sum _{n^{'}=1,n^{'}\neq n}^{N}exp(sim(z_{n,i},z_{n^{'},j})/\tau )}

这里\tau代表温度参数。

GraphCL可以看做最大化互信息的一种方式,可以将损失函数写成下列形式:

\ell=\mathbb{E}_{\mathbb{P}_{\hat{\mathcal{G}}_{i}}}\left \{-\mathbb{E}_{\mathbb{P}_{(\hat{\mathcal{G}}_{j}|\hat{\mathcal{G}}_{i})}}T(f_{1}(\hat{\mathcal{G}}_{i}),f_{2}(\hat{\mathcal{G}}_{j}))+log(\mathbb{E}_{\mathbb{P}_{\hat{\mathcal{G}}_{j}}}e^{T(f_{1}(\hat{\mathcal{G}}_{i}),f_{2}(\hat{\mathcal{G}}_{j}))})\right \}

T相当于一个discriminator,这一损失相当于最大化h_{i}=f_{1}(\hat{\mathcal{G}}_{i}),h_{j}=f_{2}(\hat{\mathcal{G}}_{j})的互信息,在GraphCL中f_1=f_2\mathcal{G}}_{i},\mathcal{G}}_{j}由数据增强获得。

三、实验

  1. 数据集统计

本文实验采用下列数据集:

数据集
  1. 数据增强的组合和选择

下图实验探究了不同数据增强组合的影响:

实验

下列实验表明,对于不同类型的增强对,对比损失的下降速度总是比相同类型的增强对慢,这说明模型更难识别不同类型的数据增强:

实验
  1. 数据增强的类型、范围和模式

下列实验数据增强的类型、范围和模式对效果的影响:

实验 实验

有以下结论:
①Edge perturbation对社会网络数据集有益,但会伤害生物分子数据集的性能;
②在密度较大的图中应用attribute masking可以获得更好的性能;
③Node dropping和subgraph对所有数据集都有益。

  1. 与SOTA方法的对比

实验如下:

实验

实验如下:

实验

实验如下:

实验

实验如下:

实验
上一篇下一篇

猜你喜欢

热点阅读