深度学习

PyG构建图对象并转换成networkx图对象

2024-01-10  本文已影响0人  马尔代夫Maldives

一、写在前面

PyG 是一款基于PyTorch 的图神经网络库,它提供了很多经典的图神经网络模型和图数据集。
在使用 PyG 框架来构建和训练图网络模型时,需要事先将图数据换成PyG定义的“图对象”
PyG 提供多种类型的图对象(在torch_geometric.data下),常用的包括:Data(同构图)和HeteroData(异构图)

无标题.png

二、基本用法(以Data对象为例)

2.1) 构建图对象

构建一张图的Data对象时,通常需要提供以下基本数据:
from torch_geometric.data import Data
Data ( x: Optional[torch.Tensor] = None,
   edge_index: Optional[torch.Tensor] = None,
   edge_attr: Optional[torch.Tensor] = None,
   y: Optional[torch.Tensor] = None,
   pos: Optional[torch.Tensor] = None,
   **kwargs)

上述信息通常需要用户提前准备好,才能构建一个Data对象,但都不是必须要提供的。一般对于一张图而言,最重要的是节点特征矩阵、边矩阵、边特征矩阵

Data 对象有点类似 Python 中的字典,属性和数据用键值对表示,因此可以用点“.”或方括号“[]”来访问、修改、增加其内部的数据,就跟字典的操作方式一样。

2.2) 图对象的方法

见‘举例1’一节的3.3)

三、举例1(简单例子)

目标:为下图创建Data对象:


原图.png
import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt

3.1)图的原始图数据准备:

首先将上图中的图(Graph)转化成对应的tensor(非常重要,决定了后面的图对象是否能正确构建)。

# 节点特征矩阵(一行对应一个节点的特征,每个节点有3个特征)
>>my_node_features = torch.tensor([[-1, -1, -1], 
                                 [-2, -2, -2],
                                 [-3, -3, -3],
                                 [-4, -4, -4]],dtype=torch.float)

# 边的节点对,共有7条边(四个节点:0、1、2、3),必须用7组节点对来表示
>>my_edge_index = torch.tensor([[0, 1, 2, 1, 3, 2, 3],
                              [1, 2, 1, 3, 1, 3, 2]], dtype=torch.long)

# 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
>>my_edge_attr = torch.tensor([[11, 11, 11, 11],
                             [22, 22, 22, 22],
                             [33, 33, 33, 33],
                             [44, 44, 44, 44],
                             [55, 55, 55, 55],
                             [66, 66, 66, 66],
                             [77, 77, 77, 77]], dtype=torch.float)

# 边权重,共有7个边权重,一条边一个
>>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)

3.2)根据图的原始数据构建PyG图对象(Data对象):

>>pyg_G = Data(x=my_node_features, 
             edge_index=my_edge_index, 
             edge_attr=my_edge_attr, 
             edge_weight=my_edge_weight)
>>print(pyg_G)
输出:
Data(x=[4, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])
PyG对象输出解读.jpg
对PyG对象输出信息的解读很重要(特别是对于无法图像化的大图)!

3.3)图对象(Data对象)提供的几种常用方法(其他方法使用‘dir(图对象)’获取):

>>pyg_G.node_attrs()
['x']
>>pyg_G.x
tensor([[-1., -1., -1.],
        [-2., -2., -2.],
        [-3., -3., -3.],
        [-4., -4., -4.]])
>>pyg_G.edge_index
tensor([[0, 1, 2, 1, 3, 2, 3],
        [1, 2, 1, 3, 1, 3, 2]])
pyg_G.edge_attrs()
['edge_weight', 'edge_attr', 'edge_index']
>>pyg_G.edge_weight
tensor([1., 2., 3., 4., 5., 6., 7.])
>>pyg_G.edge_attr
tensor([[11., 11., 11., 11.],
        [22., 22., 22., 22.],
        [33., 33., 33., 33.],
        [44., 44., 44., 44.],
        [55., 55., 55., 55.],
        [66., 66., 66., 66.],
        [77., 77., 77., 77.]])
>>pyg_G.node_stores
[{'x': tensor([[-1., -1., -1.],
         [-2., -2., -2.],
         [-3., -3., -3.],
         [-4., -4., -4.]]), 'edge_index': tensor([[0, 1, 2, 1, 3, 2, 3],
         [1, 2, 1, 3, 1, 3, 2]]), 'edge_attr': tensor([[11., 11., 11., 11.],
         [22., 22., 22., 22.],
         [33., 33., 33., 33.],
         [44., 44., 44., 44.],
         [55., 55., 55., 55.],
         [66., 66., 66., 66.],
         [77., 77., 77., 77.]]), 'edge_weight': tensor([1., 2., 3., 4., 5., 6., 7.])}]

3.4)PyG图对象与networkx图对象的转换(检查我们创建的PyG对象是否与原图一致)

https://blog.csdn.net/zzy_NIC/article/details/127996911
https://zhuanlan.zhihu.com/p/92482339
PyG主要用于图网络计算,本身没有可视化功能。可利用PyG的to_networkx()方法将PyG同构图对象转化成networkx对象,然后可视化。
to_networkx(
   data: PyG的Data或HeteroData对象,
   node_attrs: 节点属性名(可迭代str对象,默认None),
   edge_attrs: 边属性名(可迭代str对象,默认None),
   graph_attrs: 图属性名(可迭代str对象,默认None),
   to_undirected: 转换成无向图还是有向图(True/False,默认False),
   remove_self_loops: 是否将图中的loop移除(True/False,默认False),
)

■■Case1:转换时,不指定 node_attrs、edge_attrs、graph_attrs参数。
从输出结果来看,这种情况to_networkx()只会把PyG对象的节点(nodes)和边(edges)转换到networkx对象中,其他属性信息不会包含(下图中全是空{ })。其次,从输出的节点名、边的节点对以及图像来看,与最前面的‘原图’是相同的,说明我们构建的PyG是对的。

# Case1
>>nx_G = to_networkx(data=pyg_G, to_undirected=False)  # 将PyG的Data对象转化成networkx的数据对象

>>print(f'节点名:{nx_G.nodes}')
>>print(f'边的节点对:{nx_G.edges}')
>>print('每个节点的属性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
    print(node)
>>print('每条边的属性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
    print(edge)

# 画图
>>pos = nx.spring_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
>>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20)  # 绘图
>>plt.show()

输出:如下图所示
图片1.png

■■Case2:转换时,指定 node_attrs、edge_attrs、graph_attrs参数。
这种情况,首先得查看原PyG对象有哪些属性:

>>print(pyg_G.node_attrs())
>>print(pyg_G.edge_attrs())
输出:
['x']
['edge_weight', 'edge_attr', 'edge_index']

可见,该PyG对象有节点属性有['x'],边属性有['edge_weight', 'edge_attr', 'edge_index'],
于是可以在to_networkx()转换时进行指定(特别注意:'edge_index'这个属性不能写在to_networkx()的edge_attrs变量中,否则出错),见下面代码:

# Case2
>>nx_G = to_networkx(data=pyg_G, 
                   node_attrs=['x'],
                   edge_attrs=['edge_weight', 'edge_attr'],
                   to_undirected=True)  # 将PyG的Data对象转化成networkx的数据对象

>>print(f'节点名:{nx_G.nodes}')
>>print(f'边的节点对:{nx_G.edges}')
>>print('每个节点的属性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
    print(node)
>>print('每条边的属性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
    print(edge)

# 画图
>>pos = nx.spring_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
>>nx.draw(nx_G, pos, node_size=400, with_labels=True)  # 绘图
>>plt.show()
图片2.png

从上图的输出结果看,已经把PyG对象的节点和边的各种属性同时转化成networkx对象的属性了。

四、举例2(PyG对象节点、边、节点特征、边特征之间的对应关系剖析)

import torch
from torch_geometric.data import Data
import networkx as nx
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt

4.1)原始图数据准备

与前面不同的是,此例中事先并不知道图的结构,只有数据。
而且注意:
my_node_features的shape=[5,3],即节点序号为:0、1、2、3、4;
但边的节点对my_edge_index 指定的节点为:10、11、12、13。

# 节点特征矩阵(一行对应一个节点的特征,每个节点有3个特征)
>>my_node_features = torch.tensor([[-1, -1, -1], 
                                 [-2, -2, -2],
                                 [-3, -3, -3],
                                 [-4, -4, -4],
                                 [-5, -5, -5]],
                                dtype=torch.float)

# 边矩阵(这里共有7条边,必须用7组节点对来表示,节点对的前后位置可以任意调换,对结果没有影响)
>>my_edge_index = torch.tensor([[10, 11, 12, 11, 13, 13, 12],
                              [11, 12, 11, 13, 11, 12, 13]], dtype=torch.long)

# 边特征矩阵(一行对应一条边的特征,每条边有4个特征)
>>my_edge_attr = torch.tensor([[11, 11, 11, 11],
                             [22, 22, 22, 22],
                             [33, 33, 33, 33],
                             [44, 44, 44, 44],
                             [55, 55, 55, 55],
                             [66, 66, 66, 66],
                             [77, 77, 77, 77]], dtype=torch.float)

# 边权重,共设置了7个边权重
>>my_edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.float)

4.2)根据原始数据构建PyG图对象

>>pyg_G = Data(x=my_node_features, 
             edge_index=my_edge_index, 
             edge_attr=my_edge_attr, 
             edge_weight=my_edge_weight)
>>print(pyg_G)
输出:
Data(x=[5, 3], edge_index=[2, 7], edge_attr=[7, 4], edge_weight=[7])

从PyG对象的输出结果看,该图有5个节点,每个节点3个特征;共有7条边,每条边4个特征,1个权重。

输出节点和边的属性名列表:
>>print(pyg_G.node_attrs())
>>print(pyg_G.edge_attrs())
输出:
['x']
['edge_index', 'edge_weight', 'edge_attr']

4.3)将PyG对象转换成networkx对象,并成图

>>nx_G = to_networkx(data=pyg_G, 
                   node_attrs=['x'],
                   edge_attrs=['edge_weight', 'edge_attr'],
                   to_undirected=False)  # 将PyG的Data对象转化成networkx的数据对象

>>print(f'节点名:{nx_G.nodes}')
>>print(f'边的节点对:{nx_G.edges}')
>>print('每个节点的属性:')
# print(nx_G.nodes(data=True))
>>for node in nx_G.nodes(data=True):
    print(node)
>>print('每条边的属性:')
# print(nx_G.edges(data=True))
>>for edge in nx_G.edges(data=True):
    print(edge)

# 画图
>>pos = nx.circular_layout(nx_G)  # 迭代计算‘可视化图片’上每个节点的坐标
>>nx.draw(nx_G, pos, node_size=800, with_labels=True, font_size=20)  # 绘图
>>plt.show()

Case1:参数to_undirected=False,即有向图
从输出结果的节点名来看,该图共有9个节点,前面的[0,1,2,3,4]五个节点(注意,代码中我们并没有指定这些节点名)是to_networkx()根据节点特征矩阵my_node_features的行数按0,1,2……顺序自动分配的(这是PyG固定的);后面四个节点[10,11,12,13]是to_networkx()根据用户给的边的节点对矩阵my_edge_index中自动抽取并生成的
★★可见,在利用to_networkx()将PyG对象转换成networkx对象时,to_networkx会自动补充一些节点,比如这里的[0,1,2,3,4],我们将其称为冗余节点!可以写额外的代码来将这些冗余节点删除,见子图抽取的‘2.2.5 将冗余节点从子图的networkx图对象中删除’

关于边的特征和权重,PyG会自动将边特征矩阵my_edge_attr的
第1行作为第1条边【这里是(10,11)】的特征;
第2行作为第2条边【这里是(11,12)】的特征;
第3行作为第3条边【这里是(12,11)】的特征;
……
同理,PyG会自动将边权重向量my_edge_weight的
第1个值作为第1条边【这里是(10,11)】的权重;
第2个值作为第2条边【这里是(11,12)】的权重;
第3个值作为第3条边【这里是(12,11)】的权重;
……

特别注意:边特征矩阵(my_edge_attr)的行数、边权重向量(my_edge_weight)的元素个数都必须和边节点对矩阵(my_edge_index )的列数相同,否则结果会出错

14029140-5305108e9eb8121b.jpg

Case2:参数to_undirected=True,即无向图
Case2除了边有所变化以外,其他都与Cas1一样。
Case2主要为了说明to_networkx()这个函数的参数to_undirected=False/True(有向图和无向图)的区别。
Cas1是有向图,根据给定的节点对矩阵my_edge_index从起点到终点画图即可,这个没啥疑问。
Cas2是无向图:

新建 Microsoft Visio 绘图.jpg

参考:
https://zhuanlan.zhihu.com/p/599104296
https://blog.csdn.net/ARPOSPF/article/details/128398393

上一篇 下一篇

猜你喜欢

热点阅读