pytorch

PyTorch的 transpose、permute、view、

2020-04-29  本文已影响0人  top_小酱油

基础环境

Python 3.6.9 
GCC 8.3.0 on linux
PyTorch 1.4.0
>>> import torch
>>> print(torch.__version__) 
1.4.0

transpose() 和 permute()的区别

  • transpose并不改变a本身的形状,将改变的一个副本赋值给b,相当于先拷贝了一份,然后再改变这份拷贝的。
  • permute() 和 tranpose() 比较相似,transpose是交换两个维度,permute()是交换多个维度。
a = torch.randn(1, 2, 3, 4)
b = a.transpose(1, 2)
print(a.shape)#torch.Size([1, 2, 3, 4])
print(b.shape)#torch.Size([1, 3, 2, 4])

x = torch.randn(2, 3, 5)
y=x.permute(2, 0, 1)
print(y.shape)#torch.Size([5, 2, 3])

transpose()和view()的区别

  • b和c的形状虽然相同,但内容是不相等的
  • transpose的改变不等于view的改变
  • 一个不同之处在于view()只能对连续的张量进行操作,并且返回的张量仍然是连续的。transpose()既可以在连续张量上操作,也可以在非连续张量上操作。与view()不同,返回的张量可能不再是连续的。
a = torch.randn(1, 2, 3, 4)
b = a.transpose(1, 2)
print(a.shape)#torch.Size([1, 2, 3, 4])
print(b.shape)#[1, 3, 2, 4]
c = a.view(1, 3, 2, 4)
print(torch.equal(b, c))#False
print(c.shape)#torch.Size([1, 3, 2, 4])

针对连续的讨论

transpose与view 分别对tensor做了什么样的改变

import torch
x = torch.Tensor([[1,2,3],[4,5,6]])
print(x.shape)
y = x.view(3,2)
print(y.shape)
z = x.transpose(1,0)
print(z.shape)
print(x)
print(y)
print(z)
torch.Size([2, 3])
torch.Size([3, 2])
torch.Size([3, 2])
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])
tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])

这根上面的b和c的形状虽然相同,但内容是不相等的 是一样的道理

reshape()与view()的区别

  • reshape返回一个张量,该张量具有与自身相同的数据和元素数量,但具有指定的形状。如果Shape与当前形状兼容,则此方法返回一个view。有关何时可以返回view的信息。
  • reshape 封装了 view,view根据规则有时还需要调用contiguous()

permute().contiguous().view()相当于reshape

  • view返回的Tensor底层数据不会使用新的内存,如果在view中调用了contiguous方法,则可能在返回Tensor底层数据中使用了新的内存,PyTorch又提供了reshape方法,实现了类似于 contigous().view()的功能,使用reshape更方便.
  • contiguous 一般用于 transpose/permute 后和 view 前,即使用 transpose 或 permute 进行维度变换后,调用 contiguous,然后方可使用 view 对维度进行变形(如:tensor_var.contiguous().view() )

原文链接:https://blog.csdn.net/flyfish1986/article/details/105054982

上一篇下一篇

猜你喜欢

热点阅读