一文详解图神经网络(二)
《The Graph Neural Network Model》
GNN模型基于信息传播机制,每⼀个节点通过相互交换信息来更新⾃⼰的节点状态,直到达到某⼀个稳定值,GNN的输出就是在每个节点处,根据当前节点状态分别计算输出
5.1.1 任务定义
图领域的应用主要可以分为两种类型:专注于图的应用(graph-focused)
和专注于节点的应用(node-focused)
。对于graph-focused的应用,函数和具体的节点无关,(即
),训练时,在一个图的数据集中进行分类或回归。对于node-focused的应用,
函数依赖于具体的节点
,即
- 图
表⽰为
,其中
表⽰节点集,
表⽰边集
-
表示节点
的邻居节点集合
-
表示以
节点为顶点的所有边集合
-
表示节点
的特征向量
-
表示边
的特征向量
-
表示所有特征向量叠在⼀起的向量
在一个图-节点对的集合,
表示图的集合,
表示节点集合,图领域问题可以表示成一个有如下数据集的监督学习框架:
其中,表示集合
中的第
个节点,
表示节点
的期望目标(即标签)。节点
的状态用
表示,该节点的输出用
表示,
为
local transition function
,为
local output function
,那么和
的更新方式如下:
其中,分别表示节点
的特征向量、与节点
相连的边的特征向量、节点
邻居节点的状态向量、节点
邻居节点的特征向量。
分别为所有的状态、所有的输出、所有的特征向量、所有节点的特征向量的叠加起来的向量,那么上面函数可以写成如下形式:
其中,为
global transition function
,为
global output function
,分别是和
的叠加形式
根据Banach的不动点理论,假设是一个压缩映射函数,那么式子有唯一不动点解,而且可以通过迭代方式逼近该不动点
其中,表示
在第
个迭代时刻的值,对于任意初值,迭代的误差是以指数速度减小的,使用迭代的形式写出状态和输出的更新表达式为:
5.1.2 训练策略
GNN的学习就是估计参数,使得函数
能够近似估计训练集
其中,表示在图
中监督学习的节点,对于graph-focused的任务,需要增加一个特殊的节点,该节点用来作为目标节点,这样,
graph-focused
任务和node-focused
任务都能统一到节点预测任务上,学习目标可以是最小化如下二次损失函数
优化算法基于随机梯度下降的策略,优化步骤按照如下几步进行:
- 按照迭代方程迭代
次得到
,此时接近不动点解:
- 计算参数权重的梯度
- 使用该梯度来更新权重
这里假设函数
是压缩映射函数,保证最终能够收敛到不动点。另外,这里的梯度的计算使用
backpropagation-through-time algorithm
理论1(可微性):令
和
分别是
global transition function
和global output function
,如果和
对于
和
是连续可微的,那么
对
也是连续可微的
理论2(反向传播):令
和
分别是
global transition function
和global output function
,如果和
对于
和
是连续可微的。令
定义为:
那么,序列收敛到一个向量,
,并且收敛速度为指数级收敛以及与初值
无关,另外,还存在:
其中,是GNN的稳定状态
5.1.3 Transition和Output函数
在GNN中,函数不需要满足特定的约束,直接使用多层前馈神经网络,对于函数
,则需要着重考虑,因为
需要满足压缩映射的条件,而且与不动点计算相关。下面提出两种神经网络和不同的策略来满足这些需求
- Linear(nonpositional) GNN:
对于节点n nn状态的计算,将改成如下形式
相当于是对节点的每一个邻居节点使用
,并将得到的值求和来作为节点
的状态,由此,对上式中的函数
按照如下方式实现:
其中,向量,矩阵
定义为两个前向神经网络的输出。更确切地说,令产生矩阵
的网络为transition network,产生向量
的网络为forcing network
-
transition network表示为
-
forcing network表示为
-
由此,可以定义
和
其中,,
,
表示将
维的向量整理(reshape)成
的矩阵,也就是说,将transition network的输出整理成方形矩阵,然后乘以一个系数就得到
,
就是forcing network的输出
在这里,假定,这个可以通过设定transition function的激活函数来满足,比如设定激活函数为
tanh()
。在这种情况下,,
和
分别是
的块矩阵形式和
的堆叠形式,可得:
该式表示对于任意的参数
是一个压缩映射,矩阵
的
1-norm
定义为:
- Nonelinear(nonpositional) GNN:
在这个结构中,通过多层前馈网络实现,但是,并不是所有的参数
都会被使用,因为同样需要保证
是一个压缩映射函数,这个可以通过惩罚项来实现
其中,惩罚项在
时为
,在
时为0,参数
定义为希望的
的压缩系数
5.1.4 代码实现
class GNN(nn.Module):
def __init__(self, config, state_net=None, out_net=None):
super(GNN, self).__init__()
self.config = config
# hyperparameters and general properties
self.convergence_threshold = config.convergence_threshold
self.max_iterations = config.max_iterations
self.n_nodes = config.n_nodes
self.state_dim = config.state_dim
self.label_dim = config.label_dim
self.output_dim = config.output_dim
self.state_transition_hidden_dims = config.state_transition_hidden_dims
self.output_function_hidden_dims = config.output_function_hidden_dims
# node state initialization
self.node_state = torch.zeros(*[self.n_nodes, self.state_dim]).to(self.config.device) # (n,d_n)
self.converged_states = torch.zeros(*[self.n_nodes, self.state_dim]).to(self.config.device)
# state and output transition functions
if state_net is None:
self.state_transition_function = StateTransition(self.state_dim, self.label_dim,
mlp_hidden_dim=self.state_transition_hidden_dims,
activation_function=config.activation)
else:
self.state_transition_function = state_net
if out_net is None:
self.output_function = MLP(self.state_dim, self.output_function_hidden_dims, self.output_dim)
else:
self.output_function = out_net
self.graph_based = self.config.graph_based
def reset_parameters(self):
self.state_transition_function.mlp.init()
self.output_function.init()
def forward(self,
edges,
agg_matrix,
node_labels,
node_states=None,
graph_agg=None
):
n_iterations = 0
# convergence loop
# state initialization
node_states = self.node_state if node_states is None else node_states
while n_iterations < self.max_iterations:
new_state = self.state_transition_function(node_states, node_labels, edges, agg_matrix)
n_iterations += 1
# convergence condition
with torch.no_grad():
distance = torch.norm(input=new_state - node_states,
dim=1) # checked, they are the same (in cuda, some bug)
check_min = distance < self.convergence_threshold
node_states = new_state
if check_min.all():
break
states = node_states
self.converged_states = states
if self.graph_based:
states = torch.matmul(graph_agg, node_states)
output = self.output_function(states)
return output, n_iterations
NLP新人,欢迎大家一起交流,互相学习,共同成长~~