Pytorch Geometric中的图神经网络GAT是如何实现
2021-09-02 本文已影响0人
四碗饭儿
最近在使用Pytorch Geometric, 这个包收集了最新的图神经网络的Pytorch实现。这篇文章想研究下它是怎么实现GAT(Graph Attention Network)。在PyG中实现图神经网络,主要依靠MessagePassing
这个类。在继承或使用MessagePassing类时,你可以指明使用哪一种消息合并方式
MessagePassing(aggr="add", flow="src_to_tgt", node_dim=-2)
MessagePassing自带propagate函数,一般你只需在forwad里调用一下就好了
propagate(edge_index, size=None, **kwargs)# 输入边和其他必要数据,然后构造消息,更新节点的表示
message函数一般是需要你自定义的
message(...)#对于图中的每一条边$(j,i)$创建一个消息,传送个节点$i$,在pytroch geometric的代码库中,通常i指central node,j指neighboring node
Pytorch Geometric中的GAT实现源码在这里。我这里写了个精简版,方便阅读理解,代码中添加了相关注释。
class GATConv(MessagePassing):
def __init__(self):
# 超参
self.in_channels = in_channels # 节点特征的维度
self.out_channels = out_channels # 每个attention head的输出维度
self.heads = heads
self.dropout = dropout
# 模型参数
self.lin = Linear(in_channels, heads * out_channels, False)
self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) # 有点像一个Mask
def forward(self, x, edge_index, size, return_attention_weights):
# MLP
H, C = self.heads, self.out_channels
x_src = x_dst = self.lin(x).view(-1, H, C) # MLP输出整理成H个head的格式
# 计算H个head的attention
alpha_src = (x_src * self.att_src).sum(dim=-1)
# 添加`self_loop`到`edge_index`
edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
# 消息传播和更新
out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
# 拼接和输出
out = out.view(-1, self.heads * self.out_channels)
return out
def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i)
# 把两个节点的attention weight 相加
alpha = alpha_j + alpha_i
# 经过一个非线性
alpha = F.leaky_relu(alpha, self.negative_slope)
# 经过一个Softmax
alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha # Save for later use.
# 经过一个dropout
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# attention weight 乘以node feature
return x_j * alpha.unsqueeze(-1)
我不太喜欢的点在于将GAT的实现拆到每条边上, 特别是attention weight,完整写起来其实是个的矩阵,但是硬要写成message的话,就只能从矩阵中取元素了。