pytorch repeat 解析

2020-06-21  本文已影响0人  潘旭

pytorch repeat 解析

pytorch 中 Tensor.repeat 函数,能够将一个 tensor 从不同的维度上进行重复。这个能力在 Graph Attention Networks 中,有着应用。现在来看下,repeat 的能力是如何工作的?

repeat(*sizes) → Tensor
* sizes (torch.Size or int...) – The number of times to repeat this tensor along each dimension

翻译过来:

repeat 会将Tensor 在指定的维度方向上进行重复。比如设置参数是 2, 3, 4: 表示在 0 维方向上 重复2次,1 维方向上重复 3次, 2 维方向上重复4次。 注意这里的 2, 3, 4 不是指的维度方向,而是 0:2, 1:3, 2:4 在不同的维度上重复的次数。同时,也会进行维度的扩充。

import torch

def repeat_1():

    x = torch.tensor([1, 2, 3])
    print(f"x shape: {x.size()} : {x}")
    
    print(f"在 0 维上 重复2次 ---")
    xx = x.repeat(2)
    
    print(f"xx shape: {xx.size()}, {xx}")
    
    print(f"在 0 维上重复2次, 1 维上重复3次")
    xx = x.repeat(2, 3)
    print(f"xx shape: {xx.size()}, {xx}")

repeat_1()

x shape: torch.Size([3]) : tensor([1, 2, 3])
在 0 维上 重复2次 ---
xx shape: torch.Size([6]), tensor([1, 2, 3, 1, 2, 3])
在 0 维上重复2次, 1 维上重复3次
xx shape: torch.Size([2, 9]), tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3]])

上面演示的是 x 只有一个维度,现在演示 2 个维度的。比如在 "Graph Attention Networds", 需要计算 图的 attention, 那么需要将图上所有的节点进行两两拼接。也就是说:

Node = \{a_1, a_2, ..., a_n\}, a_i \in \mathbb{R}^{channel}

经过两两拼接后, 形成的拼接 graph, 如下:

Graph = \begin{bmatrix} a_{11} & a_{12} & ... & a_{1n}\\ a_{21} & a_{22} & ... & a_{2n}\\ ...\\ a_{n1} & a_{n2} & ... & a_{nn} \end{bmatrix}

其中 a_{ij} 表示 [a_i || a_j] 表示两个向量拼接,拼接后的维度是 2 \times channel

现在,当我们有了一个 Node 的矩阵,如何拼接出 Graph 就用到了 repeate.

没有 batch 的情况

Node \in \mathbb{R}^{N \times C}, 经过转换后 Graph \in \mathbb{R}^{N \times N \times 2C}

开始时候:

Node = \begin{bmatrix} a_1\\ a_{2}\\ \end{bmatrix}

其中 a_i \in \mathbb{R}^{c}

在下面的例子中 Node \in \mathbb{R}^{2 \times 3}

import torch

# node = [a1, a2]
node = torch.tensor([
                        [1, 2, 3],
                        [4, 5, 6]
                    ])
n = node.size(0)
c = node.size(-1)
print(f"node: {node.size()}, {node}")
node: torch.Size([2, 3]), tensor([[1, 2, 3],
        [4, 5, 6]])

现将 Node变成:

Node\_repeat\_1 = \begin{bmatrix} a_1\\ a_{2}\\ a_1\\ a_{2}\\ \end{bmatrix}

沿着 n 的方向上重复 n次, c 的方向上不便
重复完的 node_repeat_1 shape: (n*n, c) 也就是 (4, 3)

node_repeat_1 = node.repeat(n, 1)

print(f"node_repeat_1 shape: {node_repeat_1.size()}, {node_repeat_1}")
node_repeat_1 shape: torch.Size([4, 3]), tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])

想要链接在一起来不够,还要产生一个 Node\_repeat\_2,

Node\_repeat\_2 = \begin{bmatrix} a_1\\ a_1\\ a_2\\ a_2\\ \end{bmatrix}

这样 Node\_repeat\_1Node\_repeat\_2 经过 concat 操作就能够得到我们需要的 graph 了。

直接做这件事,需要一点技巧,在 c 这个方向上重复 n 次,然后在做一个 view变换。

node_repeat_2_tmp = node.repeat(1, n)
print(f"node_repeat_2_tmp: {node_repeat_2_tmp.size()}, {node_repeat_2_tmp}")
node_repeat_2_tmp: torch.Size([2, 6]), tensor([[1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6]])
node_repeat_2 = node_repeat_2_tmp.view(-1, c)
print(f"node_repeat_2:  {node_repeat_2.size()}, {node_repeat_2}")
node_repeat_2:  torch.Size([4, 3]), tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6]])

最后,将 node_repeat_1 与 node_repeat_2 concat 在一起就是 graph了。 注意: node_repeat_2 在前面, node_repeat_1在后面,因为 graph 第 i 行是:
graph[i] = [a_{i1}, a_{i2}, ..., a_{in}]
所以,需要 a_{i} 与其他的所有相连接,所以需要 node_repeat_2 在前面。

graph = torch.cat((node_repeat_2, node_repeat_1), dim=-1)
print(f"graph:  {graph.size()}, {graph}")
graph:  torch.Size([4, 6]), tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 4, 5, 6],
        [4, 5, 6, 1, 2, 3],
        [4, 5, 6, 4, 5, 6]])
graph_pretty = graph.view(n, n, 2 * c)
print(f"graph_pretty:  {graph_pretty.size()}, {graph_pretty}")
graph_pretty:  torch.Size([2, 2, 6]), tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 4, 5, 6]],

        [[4, 5, 6, 1, 2, 3],
         [4, 5, 6, 4, 5, 6]]])

上面的变换,用一个函数来表示:

def single_graph(node: torch.Tensor):
    assert node.dim() == 2
    
    n = node.size(0)
    c = node.size(1)
    
    repeat_1 = node.repeat(1, n).view(-1, c)
    
    assert repeat_1.size(), (n*n, c)
    
    repeat_2 = node.repeat(n, 1)
    assert repeat_2.size() == (n*n, c)
    
    graph = torch.cat((repeat_1, repeat_2), dim=-1)
    
    assert graph.size() == (n*n, 2*c)
    
    graph = graph.view(n, n, 2*c)
    assert graph.size() == (n, n, 2*c)
    return graph

single_graph(node)
tensor([[[1, 2, 3, 1, 2, 3],
         [1, 2, 3, 4, 5, 6]],

        [[4, 5, 6, 1, 2, 3],
         [4, 5, 6, 4, 5, 6]]])

带有 batch size 的 Grap 构建

前面介绍了没有 batch size 的构建方式,但是 很多时候是有 batch size 的那么构建方式就发生了变化。

batch_node = torch.tensor([
                            [
                                [1, 2, 3],
                                [4, 5, 6]
                            ],
                            [
                                [7, 8, 9],
                                [10, 11, 12]
                            ]
                        ])

def batch_graph(batch_node: torch.Tensor):
    
    assert batch_node.dim() == 3
    
    batch_size = batch_node.size(0)
    n = batch_node.size(1)
    c = batch_node.size(2)
    
    print(f"batch node shape: {batch_node.size()}")
    
    # 对 node 进行repeat, batch_size 不变, c 进行 n 次重复
    
    repeat_1 = batch_node.repeat(1, 1, n)
    
    print(f"repeat_1 shape: {repeat_1.size()} \n {repeat_1}")
    
    # view 转换回正确的数据
    repeat_1 = repeat_1.view(-1, n*n, c)
    
    print(f"repeat_1 shape: {repeat_1.size()}")
    
    assert repeat_1.size() == (batch_size, n*n, c)

    repeat_2 = batch_node.repeat(1, n, 1)
    
    assert repeat_2.size() == (batch_size, n*n, c)
    
    graph = torch.cat((repeat_1, repeat_2), dim=-1)
    
    assert graph.size() == (batch_size, n*n, 2*c)
    
    graph = graph.view(-1, n, n, 2*c)
    
    return graph

graph = batch_graph(batch_node)

print(f"graph size: {graph.size()} \n {graph}")
    
batch node shape: torch.Size([2, 2, 3])
repeat_1 shape: torch.Size([2, 2, 6]) 
 tensor([[[ 1,  2,  3,  1,  2,  3],
         [ 4,  5,  6,  4,  5,  6]],

        [[ 7,  8,  9,  7,  8,  9],
         [10, 11, 12, 10, 11, 12]]])
repeat_1 shape: torch.Size([2, 4, 3])
graph size: torch.Size([2, 2, 2, 6]) 
 tensor([[[[ 1,  2,  3,  1,  2,  3],
          [ 1,  2,  3,  4,  5,  6]],

         [[ 4,  5,  6,  1,  2,  3],
          [ 4,  5,  6,  4,  5,  6]]],


        [[[ 7,  8,  9,  7,  8,  9],
          [ 7,  8,  9, 10, 11, 12]],

         [[10, 11, 12,  7,  8,  9],
          [10, 11, 12, 10, 11, 12]]]])

上一篇下一篇

猜你喜欢

热点阅读