torch.squeeze()和torch.unsqueeze(
2019-07-23 本文已影响0人
西北小生_
1. torch.squeeze(tensor)
和numpy等库函数中的squeeze()函数作用一样,torch.squeeze()函数的作用是压缩一个tensor的维数为1的维度,使该tensor降维变成最紧凑的形式:
In [1]: import numpy as np
In [2]: import torch
In [3]: a = torch.arange(9).view(3,1,3)
In [4]: a
Out[4]:
tensor([[[0, 1, 2]],
[[3, 4, 5]],
[[6, 7, 8]]])
In [5]: a.size()
Out[5]: torch.Size([3, 1, 3])
In [6]: a.dim()
Out[6]: 3
In [7]: b = torch.squeeze(a)
In [8]: b
Out[8]:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
In [9]: b.size()
Out[9]: torch.Size([3, 3])
In [10]: b.dim()
Out[10]: 2
同样numpy中功能一样:
In [11]: c = np.arange(9).reshape(1,3,1,3)
In [12]: c
Out[12]:
array([[[[0, 1, 2]],
[[3, 4, 5]],
[[6, 7, 8]]]])
In [13]: c.shape, c.ndim
Out[13]: ((1, 3, 1, 3), 4)
In [14]: d = np.squeeze(c)
In [15]: d
Out[15]:
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
In [16]: d.shape, d.ndim
Out[16]: ((3, 3), 2)
2. torch.unsqueeze(tensor, dim)
unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。
In [17]: b
Out[17]:
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
In [18]: b.size(), b.dim()
Out[18]: (torch.Size([3, 3]), 2)
In [20]: b_un = torch.unsqueeze(b, 0)
In [21]: b_un
Out[21]:
tensor([[[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]])
In [22]: b_un.size(), b_un.dim()
Out[22]: (torch.Size([1, 3, 3]), 3)
In [23]: b_un_un = torch.unsqueeze(b_un, 3)
In [24]: b_un_un
Out[24]:
tensor([[[[0],
[1],
[2]],
[[3],
[4],
[5]],
[[6],
[7],
[8]]]])
In [25]: b_un_un.size(), b_un_un.dim()
Out[25]: (torch.Size([1, 3, 3, 1]), 4)