BatchedDGLGraph – Enable batched graph operations

2020-03-10  魏鹏飞

class dgl.BatchedDGLGraph(graph_list, node_attrs, edge_attrs)

Class for batched DGL graphs.

A BatchedDGLGraph basically merges a list of small graphs into a giant graph so that one can perform message passing and readout over a batch of graphs simultaneously.

The nodes and edges are re-indexed with a new id in the batched graph with the rule below:

item      Graph 1       Graph 2                 ...    Graph k
raw id    0, ..., N1    0, ..., N2              ...    ..., Nk
new id    0, ..., N1    N1+1, ..., N1+N2+1      ...    ..., N1+...+Nk+k-1
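
As a minimal sketch of the rule in the table, the plain-Python snippet below computes the range of new IDs for each graph from the per-graph node counts. The helper name batched_id_ranges is hypothetical and not part of DGL. Note that in the table Ni is the last raw ID of graph i, so graph i has Ni + 1 nodes; the counts 2 and 3 used here match the two example graphs in this article.

```python
from itertools import accumulate


def batched_id_ranges(num_nodes_per_graph):
    """Return (first_new_id, last_new_id) for every graph in the batch.

    Hypothetical helper for illustration only; it is not a DGL function.
    """
    # Offsets are the running totals of the node counts: 0, n1, n1 + n2, ...
    offsets = [0] + list(accumulate(num_nodes_per_graph))[:-1]
    return [(off, off + n - 1) for off, n in zip(offsets, num_nodes_per_graph)]


ranges = batched_id_ranges([2, 3])
print(ranges)  # [(0, 1), (2, 4)]
```

The second graph's raw IDs 0, 1, 2 therefore become 2, 3, 4 in the batched graph.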

The batched graph is read-only: nodes and edges cannot be added, and attempting to do so raises a RuntimeError.

Modifying the features of a BatchedDGLGraph has no effect on the original graphs. See the examples below for a workaround.

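Parameters:

graph_list: The collection of DGLGraph objects to batch.
node_attrs: The node attributes to batch. By default all node attributes are batched; if None, the batched graph has no node attributes. A subset can be selected by name.
edge_attrs: The edge attributes to batch, handled in the same way as node_attrs.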

Examples:

Instantiation:
Create two DGLGraph objects.

import dgl
import torch as th
g1 = dgl.DGLGraph()
g1.add_nodes(2)                                # Add 2 nodes
g1.add_edge(0, 1)                              # Add edge 0 -> 1
g1.ndata['hv'] = th.tensor([[0.], [1.]])       # Initialize node features
g1.edata['he'] = th.tensor([[0.]])             # Initialize edge features
g2 = dgl.DGLGraph()
g2.add_nodes(3)                                # Add 3 nodes
g2.add_edges([0, 2], [1, 1])                   # Add edges 0 -> 1, 2 -> 1
g2.ndata['hv'] = th.tensor([[2.], [3.], [4.]]) # Initialize node features
g2.edata['he'] = th.tensor([[1.], [2.]])       # Initialize edge features

Merge two DGLGraph objects into one BatchedDGLGraph object. When merging a list of graphs, we can choose to include only a subset of the attributes.

bg = dgl.batch([g1, g2], edge_attrs=None)
bg.edata

# Results:
{}

Below one can see that the nodes are re-indexed. The edges are re-indexed in the same way.

bg.nodes(), bg.ndata['hv']

# Results:
(tensor([0, 1, 2, 3, 4]),
 tensor([[0.],
         [1.],
         [2.],
         [3.],
         [4.]]))
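
To make "re-indexed in the same way" concrete for edges, here is a rough sketch in plain Python that shifts each graph's edge endpoints by the number of nodes that come before it in the batch. The helper batch_edges is hypothetical and only mimics the bookkeeping; it is not how DGL stores graphs internally.

```python
def batch_edges(edge_lists, num_nodes_per_graph):
    """Shift each graph's edge endpoints by the number of nodes before it.

    Hypothetical helper for illustration only; it is not a DGL function.
    """
    batched, offset = [], 0
    for edges, n in zip(edge_lists, num_nodes_per_graph):
        batched.extend((src + offset, dst + offset) for src, dst in edges)
        offset += n  # the next graph's nodes start after this graph's nodes
    return batched


g1_edges = [(0, 1)]          # 0 -> 1
g2_edges = [(0, 1), (2, 1)]  # 0 -> 1, 2 -> 1
print(batch_edges([g1_edges, g2_edges], [2, 3]))  # [(0, 1), (2, 3), (4, 3)]
```

The position of each pair in the returned list plays the role of the new edge ID (0, 1, 2).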

Property:
We can still get a brief summary of the graphs that constitute the batched graph.

bg.batch_size, bg.batch_num_nodes, bg.batch_num_edges

# Results:
(2, [2, 3], [1, 2])

Readout:
Another common demand for graph neural networks is graph readout, which is a function that takes in the node attributes and/or edge attributes for a graph and outputs a vector summarizing the information in the graph. BatchedDGLGraph also supports performing readout for a batch of graphs at once.

Below we take the built-in readout function sum_nodes() as an example, which sums over a particular kind of node attribute for each graph.

dgl.sum_nodes(bg, 'hv') # Sum the node attribute 'hv' for each graph.

# Results:
tensor([[1.],               # 0 + 1
        [9.]])              # 2 + 3 + 4
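
Conceptually, a batched sum readout is a segment sum over consecutive rows, with segment lengths given by batch_num_nodes. The sketch below reproduces the result above with plain Python lists; sum_per_graph is a hypothetical stand-in for illustration, not DGL's implementation.

```python
def sum_per_graph(values, batch_num_nodes):
    """Sum consecutive slices of `values`, one slice per graph.

    Hypothetical helper for illustration only; it is not a DGL function.
    """
    sums, start = [], 0
    for n in batch_num_nodes:
        sums.append(sum(values[start:start + n]))
        start += n  # move to the first node of the next graph
    return sums


hv = [0.0, 1.0, 2.0, 3.0, 4.0]    # batched node feature 'hv', flattened
print(sum_per_graph(hv, [2, 3]))  # [1.0, 9.0]
```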

Message passing:
For message passing and related operations, BatchedDGLGraph acts exactly the same as DGLGraph.
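
One way to see why this works: after re-indexing, no edge connects nodes from different graphs, so a single pass over the batched edge list never mixes information between them. The sketch below is a plain-Python illustration under that assumption, with a hypothetical sum_incoming helper rather than DGL's message passing API. The edge list is the one that batching the two example graphs would give: g1's edge 0 -> 1, followed by g2's edges with both endpoints shifted by 2.

```python
def sum_incoming(edges, feats):
    """Give every node the sum of the features of its in-neighbours.

    Hypothetical helper for illustration only; it is not a DGL function.
    """
    out = [0.0] * len(feats)
    for src, dst in edges:
        out[dst] += feats[src]  # message from src is accumulated at dst
    return out


batched_edges = [(0, 1), (2, 3), (4, 3)]
hv = [0.0, 1.0, 2.0, 3.0, 4.0]
print(sum_incoming(batched_edges, hv))  # [0.0, 0.0, 0.0, 6.0, 0.0]
```

Node 3 (raw node 1 of the second graph) receives 2 + 4 from its two in-neighbours, and nothing flows between the two graphs.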

Update Attributes:
Updating the attributes of the batched graph has no effect on the original graphs.

bg.edata['he'] = th.zeros(3, 2)
g2.edata['he']

# Results:
tensor([[1.],
        [2.]])

Instead, we can decompose the batched graph back into a list of graphs and use them to replace the original graphs.

g1, g2 = dgl.unbatch(bg)    # returns a list of DGLGraph objects
g2.edata['he']

# Results:
tensor([[0., 0.],
        [0., 0.]])
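
The decomposition can be pictured as slicing the batched feature rows back into per-graph chunks using batch_num_edges (or batch_num_nodes for node features). The following is only a sketch with plain lists; split_rows is a hypothetical helper, not a DGL function.

```python
def split_rows(rows, counts):
    """Split `rows` into consecutive chunks whose lengths are given by `counts`.

    Hypothetical helper for illustration only; it is not a DGL function.
    """
    chunks, start = [], 0
    for c in counts:
        chunks.append(rows[start:start + c])
        start += c
    return chunks


he = [[0.0, 0.0] for _ in range(3)]    # batched edge feature, 3 edges x 2 values
g1_he, g2_he = split_rows(he, [1, 2])  # batch_num_edges is [1, 2]
print(g1_he)  # [[0.0, 0.0]]
print(g2_he)  # [[0.0, 0.0], [0.0, 0.0]]
```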

Merge and decompose

batch(graph_list[, node_attrs, edge_attrs])
    Batch a collection of DGLGraph objects and return a BatchedDGLGraph object that is independent of graph_list.
unbatch(graph)
    Return the list of graphs in this batch.

Query batch summary

BatchedDGLGraph.batch_size
    Number of graphs in this batch.
BatchedDGLGraph.batch_num_nodes
    Number of nodes of each graph in this batch.
BatchedDGLGraph.batch_num_edges
    Number of edges of each graph in this batch.

Graph Readout

sum_nodes(graph, feat[, weight])
    Sum all the values of node field feat in graph, optionally multiplying the field by a scalar node field weight.
sum_edges(graph, feat[, weight])
    Sum all the values of edge field feat in graph, optionally multiplying the field by a scalar edge field weight.
mean_nodes(graph, feat[, weight])
    Average all the values of node field feat in graph, optionally multiplying the field by a scalar node field weight.
mean_edges(graph, feat[, weight])
    Average all the values of edge field feat in graph, optionally multiplying the field by a scalar edge field weight.
max_nodes(graph, feat)
    Take the elementwise maximum over all the values of node field feat in graph.
max_edges(graph, feat)
    Take the elementwise maximum over all the values of edge field feat in graph.
topk_nodes(graph, feat, k[, descending, idx])
    Return the graph-wise top-k node features of field feat in graph, ranked by the keys at the given index idx.
topk_edges(graph, feat, k[, descending, idx])
    Return the graph-wise top-k edge features of field feat in graph, ranked by the keys at the given index idx.
softmax_nodes(graph, feat)
    Apply a batch-wise graph-level softmax over all the values of node field feat in graph.
softmax_edges(graph, feat)
    Apply a batch-wise graph-level softmax over all the values of edge field feat in graph.
broadcast_nodes(graph, feat_data)
    Broadcast feat_data to all nodes in graph, and return a tensor of node features.
broadcast_edges(graph, feat_data)
    Broadcast feat_data to all edges in graph, and return a tensor of edge features.
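
The mean, max and broadcast operations listed above share the same per-graph segment pattern. The sketch below shows that pattern with plain Python lists, assuming scalar features and at least one node per graph. All function names are hypothetical and the code is not DGL's implementation.

```python
def segments(values, counts):
    """Yield consecutive slices of `values`, one per graph."""
    start = 0
    for c in counts:
        yield values[start:start + c]
        start += c


def mean_per_graph(values, counts):
    # Assumes every graph has at least one node (no empty segments).
    return [sum(seg) / len(seg) for seg in segments(values, counts)]


def max_per_graph(values, counts):
    return [max(seg) for seg in segments(values, counts)]


def broadcast_per_graph(graph_values, counts):
    # Repeat each graph-level value once for every node of that graph.
    return [v for v, c in zip(graph_values, counts) for _ in range(c)]


hv = [0.0, 1.0, 2.0, 3.0, 4.0]
counts = [2, 3]                                   # batch_num_nodes
print(mean_per_graph(hv, counts))                 # [0.5, 3.0]
print(max_per_graph(hv, counts))                  # [1.0, 4.0]
print(broadcast_per_graph([10.0, 20.0], counts))  # [10.0, 10.0, 20.0, 20.0, 20.0]
```

Broadcasting is the inverse direction of a readout: it takes one value per graph and expands it to one value per node.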

Original link:
https://docs.dgl.ai/api/python/batch.html?highlight=sum_nodes
