PyTorch Gather 函数

2022-02-14  本文已影响0人  数科每日

PyTorch 的 Gather 函数很实用,但是理解起来有些困难,本文试图用图例和代码给出解释。 完整代码

Gather 主要有三个参数

Gather 函数返回值和 index 相同

Dim=0
Dim=0
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
                     
output = torch.gather(input, dim, index)
output

Dim = 0 的时候, 从外层选择, 最内层的 list Tensor 会被拆开:

image.png
Dim=1
Dim=1
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])

output = torch.gather(input, dim, index)
output

Dim = 1 的时候, 从内层选择:

image.png
上一篇下一篇

猜你喜欢

热点阅读