Pytorch Geometric中的图神经网络GAT是如何实现

2021-09-02  本文已影响0人  四碗饭儿

最近在使用Pytorch Geometric, 这个包收集了最新的图神经网络的Pytorch实现。这篇文章想研究下它是怎么实现GAT(Graph Attention Network)。在PyG中实现图神经网络,主要依靠MessagePassing这个类。在继承或使用MessagePassing类时,你可以指明使用哪一种消息合并方式

MessagePassing(aggr="add", flow="src_to_tgt", node_dim=-2)

MessagePassing自带propagate函数,一般你只需在forwad里调用一下就好了

propagate(edge_index, size=None, **kwargs)# 输入边和其他必要数据,然后构造消息,更新节点的表示

message函数一般是需要你自定义的

message(...)#对于图中的每一条边$(j,i)$创建一个消息,传送个节点$i$,在pytroch geometric的代码库中,通常i指central node,j指neighboring node

Pytorch Geometric中的GAT实现源码在这里。我这里写了个精简版,方便阅读理解,代码中添加了相关注释。

class GATConv(MessagePassing):
def __init__(self):
    # 超参
    self.in_channels = in_channels # 节点特征的维度
    self.out_channels = out_channels # 每个attention head的输出维度
    self.heads = heads
    self.dropout = dropout
    # 模型参数
    self.lin = Linear(in_channels, heads * out_channels, False)
    self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) # 有点像一个Mask

def forward(self, x, edge_index, size, return_attention_weights):
    # MLP
    H, C = self.heads, self.out_channels
    x_src = x_dst = self.lin(x).view(-1, H, C) # MLP输出整理成H个head的格式
    # 计算H个head的attention
    alpha_src = (x_src * self.att_src).sum(dim=-1)
    # 添加`self_loop`到`edge_index`
    edge_index, _  = add_self_loops(edge_index, num_nodes=num_nodes)
    # 消息传播和更新
    out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
    # 拼接和输出
    out = out.view(-1, self.heads * self.out_channels)
    return out

def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i)
    # 把两个节点的attention weight 相加
    alpha =  alpha_j + alpha_i
    # 经过一个非线性
    alpha = F.leaky_relu(alpha, self.negative_slope)
    # 经过一个Softmax    
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha  # Save for later use.
    # 经过一个dropout
     alpha = F.dropout(alpha, p=self.dropout, training=self.training)
     # attention weight 乘以node feature
     return x_j * alpha.unsqueeze(-1)

我不太喜欢的点在于将GAT的实现拆到每条边上, 特别是attention weight,完整写起来其实是个N \times N的矩阵,但是硬要写成message的话,就只能从矩阵中取元素了。

上一篇下一篇

猜你喜欢

热点阅读