torch.narrow()函数使用详解

2021-04-12  本文已影响0人  一位学有余力的同学

torch.narrow()函数是用来返回Tensor的切片的,它的使用方法如下:

torch.narrow(input, dim, start, length)

  • input– 待处理的tensor
  • dim – 维度,当为0时以行为单位进行切片,当为1时以列为单位进行切片
  • start – 切片开始的索引
  • length – 切片的长度

下面用一个例子来加以说明:

>>> a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
>>> torch.narrow(a, 0, 0, 2)
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> torch.narrow(a, 1, 1, 2)
tensor([[2, 3],
        [5, 6],
        [8, 9]])

如果不传入input,也可以直接对tensor进行操作:

>>> a.narrow(0,0,2)
tensor([[1, 2, 3],
        [4, 5, 6]])

同时,Numpy上的快速切片方法在Pytorch上也同样适用:

>>> a[:,0:2]
tensor([[1, 2],
        [4, 5],
        [7, 8]])

参考:
pytorch官网

上一篇 下一篇

猜你喜欢

热点阅读