图分类预测

2022-01-20  本文已影响0人  笑傲NLP江湖

原创:梁华雄

导入

图级别的预测可以完成对整个图属性的预测,比如在生化预测任务中,可以实现对某个分子是否产生变异进行预判。在非欧几里得的数据结构中,如社交网络(facebook),生物网络(基因,分子),基础设施网络(能源,交通,互联网,通信)具有重要的意义。

1. 原理

整图预测是针对图层面的学习任务,比如判断某药物分子是否具有某种理化性质,再比如判断某社团是否具有欺诈可能,这需要我们对整个图提取它的特征表示,然后再基于此构建我们的学习任务,图的整体特征无外乎来源于三部分:1)节点特征;2)边特征;3)结构信息,基于这些信息,我们可以通过许多方式来构建图特征,DGL提供了一些简单的API,比如对各节点特征求和/求平均/pooling等,这可以方便我们构建一些基准图预测模型,下面我们利用对节点特征求平均的方式构建图特征,这可以通过dgl.mean_nodes这个API很方便的实现,它相当于做了如下计算:
h_g=\frac{1}{|V|}\sum_{v\in{V}}{h_v}
h_v表示节点v的特征,然后基于h_g特征向量,构建我们预测模型。

2. 实现

利用dgl自带的MiniGCDataset数据集,它包括如下的8种类别的图结构,数据集包含8种不同类型的图形。

2.1 数据集

#1.导入数据
import dgl
import torch
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
#这里,随机构造了80个图,每个图是少10条边,最多30条边
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[1]
#绘制图像
%matplotlib inline
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()

2.2 分类器

构建分类器,这里采用两层,最后接一个线性分类器来实现图的分类,代码如下:

#2.定义模型
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as F
from torch import nn
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)#线性分类器
    def forward(self, g):
        # 以节点度作为初始节点特征。对于无向图,入度与外度相同。
        h = g.in_degrees().view(-1, 1).float()
        # 执行图形卷积和激活函数
        h = F.relu(self.conv1(g,h))
        h = F.relu(self.conv2(g,h))
        g.ndata['h'] = h
        # 通过对所有节点表示求平均来计算图形表示。
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

2.3 训练

开始训练,训练500次。

# 训练集/测试集
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)

#batch训练
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# 构建模型
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(500):
    epoch_loss = 0
    for i, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (i + 1)
#     print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
plt.plot(epoch_losses)
plt.legend(["loss"])

2.4 测试

#4.测试
model.eval()
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
pred_Y = torch.max(model(test_bg), 1)[1].view(-1, 1)
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == pred_Y.float()).sum().item() / len(test_Y) * 100))

2.5 混淆矩阵

#5.查看混淆矩阵
from sklearn.metrics import confusion_matrix
confusion_matrix(test_Y, pred_Y)

通过混淆矩阵,可以看到在calss1和class5这两类区分不明显,但是其他的类都基本都正确分类出来了。

总结

图的分类:对于整个图结构来说,我们可以对图分类,图分类又称为图的同构问题,基本思路是将图中节点的特征聚合起来作为图的特征,再进行分类。

上一篇 下一篇

猜你喜欢

热点阅读