2. Pytorch中torch.index_select

2021-03-22  本文已影响0人  yoyo9999

torch.index_select(input, dim, index, *, out=None) → Tensor

作用是:

Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.

返回按照相应维度的给定index的选取的元素,index必须是longtensor。

x = torch.rand((4, 3, 2))
print('x:',x)
indices0 = torch.LongTensor([0, 1])
x_0 = torch.index_select(x, dim=0, index=indices0)
print("x_0:", x_0)
indices1 = torch.LongTensor([1, 2])
x_1 = torch.index_select(x, dim=1, index=indices1)
print("x_1:", x_1)

++++++++++++++++++++++++++++++++++++++++++
x: tensor([[[0.9854, 0.4894],
         [0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.0447, 0.9854],
         [0.6996, 0.1671],
         [0.4965, 0.5742]],

        [[0.9878, 0.9571],
         [0.9090, 0.5475],
         [0.6792, 0.4184]],

        [[0.2394, 0.9625],
         [0.1951, 0.2918],
         [0.3154, 0.2175]]])
x_0: tensor([[[0.9854, 0.4894],
         [0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.0447, 0.9854],
         [0.6996, 0.1671],
         [0.4965, 0.5742]]])
x_1: tensor([[[0.3774, 0.6066],
         [0.5971, 0.7116]],

        [[0.6996, 0.1671],
         [0.4965, 0.5742]],

        [[0.9090, 0.5475],
         [0.6792, 0.4184]],

        [[0.1951, 0.2918],
         [0.3154, 0.2175]]])


上一篇 下一篇

猜你喜欢

热点阅读