图分类预测
2022-01-20 本文已影响0人
笑傲NLP江湖
原创:梁华雄
导入
图级别的预测可以完成对整个图属性的预测,比如在生化预测任务中,可以实现对某个分子是否产生变异进行预判。在非欧几里得的数据结构中,如社交网络(facebook),生物网络(基因,分子),基础设施网络(能源,交通,互联网,通信)具有重要的意义。
1. 原理
整图预测是针对图层面的学习任务,比如判断某药物分子是否具有某种理化性质,再比如判断某社团是否具有欺诈可能,这需要我们对整个图提取它的特征表示,然后再基于此构建我们的学习任务,图的整体特征无外乎来源于三部分:1)节点特征;2)边特征;3)结构信息,基于这些信息,我们可以通过许多方式来构建图特征,DGL提供了一些简单的API,比如对各节点特征求和/求平均/pooling等,这可以方便我们构建一些基准图预测模型,下面我们利用对节点特征求平均的方式构建图特征,这可以通过dgl.mean_nodes这个API很方便的实现,它相当于做了如下计算:
表示节点
的特征,然后基于
特征向量,构建我们预测模型。
2. 实现
利用dgl自带的MiniGCDataset数据集,它包括如下的8种类别的图结构,数据集包含8种不同类型的图形。
- 第0类:循环图
- 第1类:星形图
- 第2类:车轮图
- 第3类:棒棒糖图
- 第4类:超立方体图
- 第5类:网格图
- 第6类:集团图
- 第7类:圆形梯形图
2.1 数据集
#1.导入数据
import dgl
import torch
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
#这里,随机构造了80个图,每个图是少10条边,最多30条边
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[1]
#绘制图像
%matplotlib inline
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
2.2 分类器
构建分类器,这里采用两层,最后接一个线性分类器来实现图的分类,代码如下:
#2.定义模型
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as F
from torch import nn
class Classifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes):
super(Classifier, self).__init__()
self.conv1 = GraphConv(in_dim, hidden_dim)
self.conv2 = GraphConv(hidden_dim, hidden_dim)
self.classify = nn.Linear(hidden_dim, n_classes)#线性分类器
def forward(self, g):
# 以节点度作为初始节点特征。对于无向图,入度与外度相同。
h = g.in_degrees().view(-1, 1).float()
# 执行图形卷积和激活函数
h = F.relu(self.conv1(g,h))
h = F.relu(self.conv2(g,h))
g.ndata['h'] = h
# 通过对所有节点表示求平均来计算图形表示。
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
2.3 训练
开始训练,训练500次。
# 训练集/测试集
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
#batch训练
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
collate_fn=collate)
# 构建模型
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(500):
epoch_loss = 0
for i, (bg, label) in enumerate(data_loader):
prediction = model(bg)
loss = loss_func(prediction, label.long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.detach().item()
epoch_loss /= (i + 1)
# print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
epoch_losses.append(epoch_loss)
plt.plot(epoch_losses)
plt.legend(["loss"])
2.4 测试
#4.测试
model.eval()
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
pred_Y = torch.max(model(test_bg), 1)[1].view(-1, 1)
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
(test_Y == pred_Y.float()).sum().item() / len(test_Y) * 100))
2.5 混淆矩阵
#5.查看混淆矩阵
from sklearn.metrics import confusion_matrix
confusion_matrix(test_Y, pred_Y)
通过混淆矩阵,可以看到在calss1和class5这两类区分不明显,但是其他的类都基本都正确分类出来了。
总结
图的分类:对于整个图结构来说,我们可以对图分类,图分类又称为图的同构问题,基本思路是将图中节点的特征聚合起来作为图的特征,再进行分类。