pytorch repeat 解析
pytorch repeat 解析
pytorch 中 Tensor.repeat
函数,能够将一个 tensor 从不同的维度上进行重复。这个能力在 Graph Attention Networks
中,有着应用。现在来看下,repeat 的能力是如何工作的?
repeat(*sizes) → Tensor
* sizes (torch.Size or int...) – The number of times to repeat this tensor along each dimension
翻译过来:
repeat
会将Tensor 在指定的维度方向上进行重复。比如设置参数是 2, 3, 4
: 表示在 0 维方向上 重复2次,1 维方向上重复 3次, 2 维方向上重复4次。 注意这里的 2, 3, 4
不是指的维度方向,而是 0:2, 1:3, 2:4
在不同的维度上重复的次数。同时,也会进行维度的扩充。
import torch
def repeat_1():
x = torch.tensor([1, 2, 3])
print(f"x shape: {x.size()} : {x}")
print(f"在 0 维上 重复2次 ---")
xx = x.repeat(2)
print(f"xx shape: {xx.size()}, {xx}")
print(f"在 0 维上重复2次, 1 维上重复3次")
xx = x.repeat(2, 3)
print(f"xx shape: {xx.size()}, {xx}")
repeat_1()
x shape: torch.Size([3]) : tensor([1, 2, 3])
在 0 维上 重复2次 ---
xx shape: torch.Size([6]), tensor([1, 2, 3, 1, 2, 3])
在 0 维上重复2次, 1 维上重复3次
xx shape: torch.Size([2, 9]), tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3, 1, 2, 3]])
上面演示的是 x 只有一个维度,现在演示 2 个维度的。比如在 "Graph Attention Networds", 需要计算 图的 attention, 那么需要将图上所有的节点进行两两拼接。也就是说:
经过两两拼接后, 形成的拼接 graph, 如下:
其中 表示 表示两个向量拼接,拼接后的维度是
现在,当我们有了一个 的矩阵,如何拼接出 就用到了 repeate.
没有 batch 的情况
, 经过转换后
开始时候:
其中
在下面的例子中
import torch
# node = [a1, a2]
node = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
n = node.size(0)
c = node.size(-1)
print(f"node: {node.size()}, {node}")
node: torch.Size([2, 3]), tensor([[1, 2, 3],
[4, 5, 6]])
现将 变成:
沿着 n 的方向上重复 n次, c 的方向上不便
重复完的 node_repeat_1 shape: (n*n, c)
也就是 (4, 3)
node_repeat_1 = node.repeat(n, 1)
print(f"node_repeat_1 shape: {node_repeat_1.size()}, {node_repeat_1}")
node_repeat_1 shape: torch.Size([4, 3]), tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
想要链接在一起来不够,还要产生一个 ,
这样 和 经过 concat
操作就能够得到我们需要的 了。
直接做这件事,需要一点技巧,在 c 这个方向上重复 n 次,然后在做一个 view变换。
node_repeat_2_tmp = node.repeat(1, n)
print(f"node_repeat_2_tmp: {node_repeat_2_tmp.size()}, {node_repeat_2_tmp}")
node_repeat_2_tmp: torch.Size([2, 6]), tensor([[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]])
node_repeat_2 = node_repeat_2_tmp.view(-1, c)
print(f"node_repeat_2: {node_repeat_2.size()}, {node_repeat_2}")
node_repeat_2: torch.Size([4, 3]), tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6]])
最后,将 node_repeat_1 与 node_repeat_2 concat 在一起就是 graph了。 注意: node_repeat_2 在前面, node_repeat_1在后面,因为 graph 第 i 行是:
所以,需要 与其他的所有相连接,所以需要 node_repeat_2 在前面。
graph = torch.cat((node_repeat_2, node_repeat_1), dim=-1)
print(f"graph: {graph.size()}, {graph}")
graph: torch.Size([4, 6]), tensor([[1, 2, 3, 1, 2, 3],
[1, 2, 3, 4, 5, 6],
[4, 5, 6, 1, 2, 3],
[4, 5, 6, 4, 5, 6]])
graph_pretty = graph.view(n, n, 2 * c)
print(f"graph_pretty: {graph_pretty.size()}, {graph_pretty}")
graph_pretty: torch.Size([2, 2, 6]), tensor([[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 4, 5, 6]],
[[4, 5, 6, 1, 2, 3],
[4, 5, 6, 4, 5, 6]]])
上面的变换,用一个函数来表示:
def single_graph(node: torch.Tensor):
assert node.dim() == 2
n = node.size(0)
c = node.size(1)
repeat_1 = node.repeat(1, n).view(-1, c)
assert repeat_1.size(), (n*n, c)
repeat_2 = node.repeat(n, 1)
assert repeat_2.size() == (n*n, c)
graph = torch.cat((repeat_1, repeat_2), dim=-1)
assert graph.size() == (n*n, 2*c)
graph = graph.view(n, n, 2*c)
assert graph.size() == (n, n, 2*c)
return graph
single_graph(node)
tensor([[[1, 2, 3, 1, 2, 3],
[1, 2, 3, 4, 5, 6]],
[[4, 5, 6, 1, 2, 3],
[4, 5, 6, 4, 5, 6]]])
带有 batch size 的 Grap 构建
前面介绍了没有 batch size 的构建方式,但是 很多时候是有 batch size 的那么构建方式就发生了变化。
batch_node = torch.tensor([
[
[1, 2, 3],
[4, 5, 6]
],
[
[7, 8, 9],
[10, 11, 12]
]
])
def batch_graph(batch_node: torch.Tensor):
assert batch_node.dim() == 3
batch_size = batch_node.size(0)
n = batch_node.size(1)
c = batch_node.size(2)
print(f"batch node shape: {batch_node.size()}")
# 对 node 进行repeat, batch_size 不变, c 进行 n 次重复
repeat_1 = batch_node.repeat(1, 1, n)
print(f"repeat_1 shape: {repeat_1.size()} \n {repeat_1}")
# view 转换回正确的数据
repeat_1 = repeat_1.view(-1, n*n, c)
print(f"repeat_1 shape: {repeat_1.size()}")
assert repeat_1.size() == (batch_size, n*n, c)
repeat_2 = batch_node.repeat(1, n, 1)
assert repeat_2.size() == (batch_size, n*n, c)
graph = torch.cat((repeat_1, repeat_2), dim=-1)
assert graph.size() == (batch_size, n*n, 2*c)
graph = graph.view(-1, n, n, 2*c)
return graph
graph = batch_graph(batch_node)
print(f"graph size: {graph.size()} \n {graph}")
batch node shape: torch.Size([2, 2, 3])
repeat_1 shape: torch.Size([2, 2, 6])
tensor([[[ 1, 2, 3, 1, 2, 3],
[ 4, 5, 6, 4, 5, 6]],
[[ 7, 8, 9, 7, 8, 9],
[10, 11, 12, 10, 11, 12]]])
repeat_1 shape: torch.Size([2, 4, 3])
graph size: torch.Size([2, 2, 2, 6])
tensor([[[[ 1, 2, 3, 1, 2, 3],
[ 1, 2, 3, 4, 5, 6]],
[[ 4, 5, 6, 1, 2, 3],
[ 4, 5, 6, 4, 5, 6]]],
[[[ 7, 8, 9, 7, 8, 9],
[ 7, 8, 9, 10, 11, 12]],
[[10, 11, 12, 7, 8, 9],
[10, 11, 12, 10, 11, 12]]]])