人工智能

InfoGraph:基于互信息最大化的无监督和半监督图表示学习

2022-04-10  本文已影响0人  酷酷的群

论文标题:InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization
论文链接:https://arxiv.org/abs/1908.01000
论文来源:ICLR 2020
代码地址:https://github.com/fanyun-sun/InfoGraph

之前的相关博客:
MINE:随机变量互信息的估计方法
Deep InfoMax:基于互信息最大化的表示学习

一、概述

本文提出的InfoGraph是一种基于互信息最大化的图对比学习方法,与Deep Graph Infomax(DIM)相比,虽然都是基于互信息最大化的方法,InfoGraph更加侧重于图的表示学习,而DIM偏重于节点的表示学习。

过去的图相关的任务大多是监督学习任务,而图数据的标注通常是困难的和繁琐的,而对于大量的无标注数据未能有效利用。本文提出的InfoGraph侧重于图的无监督表示学习,另外InfoGraph*是在InfoGraph基础上拓展的半监督学习方法。InfoGraph*应用一个类似于Mean-Teacher方法的student-teacher框架,通过让一个encoder学习另一个encoder(最大化两者的互信息)从而在半监督任务上产生了较好的效果。

二、方法

  1. 问题定义

给定一个图集合\mathbb{G}=\left \{G_{1},G_{2},\cdots \right \}以及一个正整数\delta(也就是embedding size),我们的目标是学习每个图G_{i}\in \mathbb{G}\delta维表示,我们将G_i的节点数记作|G_i|,将所有图的表示矩阵记作\Phi \in \mathbb{R}^{|G|\times \delta }

给定标注图集合\mathbb{G}^{L}=\left \{G_{1},\cdots ,G_{|\mathbb{G}^{L}|}\right \}以及对应的输出\left \{o_{1},\cdots ,o_{|\mathbb{G}^{L}|}\right \},以及一个未标注图集合\mathbb{G}^{U}=\left \{G_{|\mathbb{G}^{L}|+1},\cdots ,G_{|\mathbb{G}^{L}|+|\mathbb{G}^{U}|}\right \},我们的目标是学习一个模型能够在未见图上进行预测。注意在大多数情况下|\mathbb{G}^{U}|\gg |\mathbb{G}^{L}|

  1. InfoGraph

首先采用一个encoder获得图的节点表示(patch表示),然后使用readout函数来聚合获得的节点表示以得到图的表示。本文采用的encoder通过聚合邻居节点的特征来获得节点的表示:

h_{v}^{(k)}=COMBINE^{(k)}\left (h_{v}^{(k-1)},AGGREGATE\left (\left \{(h_{v}^{(k-1)},h_{u}^{(k-1)},e_{uv}):u\in N(v)\right \}\right )\right )

这里h_{v}^{(k)}是节点v在第k层的节点表示,e_{uv}uv之间的边的特征向量,N(v)是节点v的邻居节点集合。h_{v}^{(0)}使用节点原生特征来初始化。本文采用的encoder为Graph Isomorphism Network (GIN)。Readout函数可以是简单的平均或者也可以采用一些更复杂的图池化函数。

我们通过最大化图表示和patch表示之间的互信息来获得图的表示。通过这样的方式,图表示能够学习编码数据结构中共享的信息。假设给定一个图集合\mathrm{G}:=\left \{G_{j}\in \mathbb{G}_{j=1}^{N}\right \},这些图服从一个经验分布\mathbb{P}。对于一个K层encoder神经网络,其参数记作\phi,在通过encoder的第k层后,输入图的节点特征被编码成一系列patch表示特征向量\left \{h_{i}^{(k)}\right \}_{i=1}^{N},然后我们将encoder的每层表示聚合成一个图的表示向量:

h_{\phi }^{i}=\mathrm{CONCAT}(\left \{h_{i}^{(k)}\right \}_{k=1}^{K})\\ H_{\phi }(G)=\mathrm{READOUT}(\left \{h_{\phi }^{i}\right \}_{i=1}^{N})

这里的H_{\phi }(G)就是图的表示。接着我们定义图的表示与patch表示对之间的互信息,学习图的表示的过程就是最大化这个互信息的过程:

\hat{\phi },\hat{\psi }=\underset{\phi ,\psi }{argmax}\sum _{G\in \mathrm{G}}\frac{1}{|G|}\sum _{u\in G}I_{\phi ,\psi }(\vec{h}_{\phi }^{u};H_{\phi }(G))

I_{\phi ,\psi }是一个互信息estimator,包含一个由\psi参数化的discriminatorT_{\psi }。我们使用Jensen-Shannon互信息估计:

I_{\phi ,\psi }(h_{\phi }^{i}(G);H_{\phi }(G)):=E_{\mathbb{P}}[-sp(-T_{\phi ,\psi }(h_{\phi }^{i}(x),H_{\phi }(x)))]-E_{\mathbb{P}\times \mathbb{\tilde{P}}}[sp(T_{\phi ,\psi }(h_{\phi }^{i}(x^{'}),H_{\phi }(x)))]

这里x是从\mathbb{P}中采样的输入样本,x^{'}是从\mathbb{\tilde{P}}=\mathbb{P}中采样的负样本,sp(z)=log(1+e^{z})是softplus函数。在实践中,我们使用一个batch内的所有的图表示与patch表示的两两组合来作为负样本。

由于H_{\phi }(G)被要求与当前图的所有patch表示的互信息最大,那么H_{\phi }(G)就倾向于编码图中共享的信息。整个算法如下图所示:

InfoGraph算法
  1. 半监督InfoGraph

一个比较直接的将无监督的方法拓展成半监督的方式是将无监督的损失作为有监督目标的正则项,如下:

L_{total}=\sum_{i=1}^{|\mathbb{G}^{L}|}L_{supervised}(y_{\phi }(G_{i}),o_{i})+\lambda \sum_{j=1}^{|\mathbb{G}^{L}|+|\mathbb{G}^{U}|}L_{unsupervised}(h_{\phi }(G_{j});H_{\phi }(G_{j}))

L_{unsupervised}(h_{\phi }(G_{j});H_{\phi }(G_{j}))是上面InfoGraph的损失,在这里应用在所有的有标注和无标注数据上。\lambda是超参数。

这样的设计直观看来,在学习预测相应的监督标签时,模型将受益于从大量无标签数据中学习良好的表示。然而,监督任务和非监督任务可能偏好不同的信息或不同的语义空间。简单地使用同一个encoder来结合这两个损失可能会造成“负迁移”(negative transfer)。本文对此改进的方法就是采用两个encoder,也就是一个监督encoder和一个无监督encoder。为了将学习到的表示从无监督encoder转移到监督encoder,我们定义了一个损失项,它鼓励两个encoder学习到的表示在所有层上都具有高的互信息。使用\varphi代表另一个K层encoder的参数,两个encoder是完全一样的。损失函数如下:

L_{total}=\sum_{i=1}^{|\mathbb{G}^{L}|}L_{supervised}(y_{\phi }(G_{i}),o_{i})+\sum_{j=1}^{|\mathbb{G}^{L}|+|\mathbb{G}^{U}|}L_{unsupervised}(h_{\varphi }(G_{j});H_{\varphi }(G_{j}))\\ -\lambda \sum_{j=1}^{|\mathbb{G}^{L}|+|\mathbb{G}^{U}|}\frac{1}{|G_{j}|}\sum_{k=1}^{K}I(H_{\phi }^{k}(G_{j});H_{\varphi }^{k}(G_{j}))

这里的H_{\phi }^{k}(G),H_{\varphi }^{k}(G)指encoder的第k层的图的全局表示。整个过程如下图所示:

InfoGraph*

这种半监督的InfoGraph方法称为InfoGraph*。注意,InfoGraph*可以看作是student-teacher框架的一个特殊实例。然而,与最近的student-teacher半监督学习方法不同,这些方法使得学生模型的预测与教师模型相似,而InfoGraph*通过在表示的各层上的互信息最大化来实现知识从教师模型向学生模型的转移。

三、实验

本文在MUTAG, PTC, REDDIT-BINARY, REDDIT-MULTI-5K, IMDB-BINARY和IMDB-MULTI一共6个数据集上进行分类任务实验,在QM9数据集上进行半监督实验。实验结果如下:

实验 实验
上一篇下一篇

猜你喜欢

热点阅读