Pytorch

torch.squeeze()和torch.unsqueeze(

2019-07-23  本文已影响0人  西北小生_
1. torch.squeeze(tensor)

和numpy等库函数中的squeeze()函数作用一样,torch.squeeze()函数的作用是压缩一个tensor的维数为1的维度,使该tensor降维变成最紧凑的形式:

In [1]: import numpy as np                                                      

In [2]: import torch                                                            

In [3]: a = torch.arange(9).view(3,1,3)                                         

In [4]: a                                                                       
Out[4]: 
tensor([[[0, 1, 2]],

        [[3, 4, 5]],

        [[6, 7, 8]]])

In [5]: a.size()                                                                
Out[5]: torch.Size([3, 1, 3])

In [6]: a.dim()                                                                 
Out[6]: 3

In [7]: b = torch.squeeze(a)                                                    

In [8]: b                                                                       
Out[8]: 
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [9]: b.size()                                                                
Out[9]: torch.Size([3, 3])

In [10]: b.dim()                                                                
Out[10]: 2

同样numpy中功能一样:

In [11]: c = np.arange(9).reshape(1,3,1,3)                                      

In [12]: c                                                                      
Out[12]: 
array([[[[0, 1, 2]],

        [[3, 4, 5]],

        [[6, 7, 8]]]])

In [13]: c.shape, c.ndim                                                        
Out[13]: ((1, 3, 1, 3), 4)

In [14]: d = np.squeeze(c)                                                      

In [15]: d                                                                      
Out[15]: 
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

In [16]: d.shape, d.ndim                                                        
Out[16]: ((3, 3), 2)
2. torch.unsqueeze(tensor, dim)

unsqueeze()函数的功能是在tensor的某个维度上添加一个维数为1的维度,这个功能用view()函数也可以实现。这一功能尤其在神经网络输入单个样本时很有用,由于pytorch神经网络要求的输入都是mini-batch型的,维度为[batch_size, channels, w, h],而一个样本的维度为[c, w, h],此时用unsqueeze()增加一个维度变为[1, c, w, h]就很方便了。

In [17]: b                                                                      
Out[17]: 
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

In [18]: b.size(), b.dim()                                                      
Out[18]: (torch.Size([3, 3]), 2)

In [20]: b_un = torch.unsqueeze(b, 0)                                           

In [21]: b_un                                                                   
Out[21]: 
tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])

In [22]: b_un.size(), b_un.dim()                                                
Out[22]: (torch.Size([1, 3, 3]), 3)

In [23]: b_un_un = torch.unsqueeze(b_un, 3)                                     

In [24]: b_un_un                                                                
Out[24]: 
tensor([[[[0],
          [1],
          [2]],

         [[3],
          [4],
          [5]],

         [[6],
          [7],
          [8]]]])

In [25]: b_un_un.size(), b_un_un.dim()                                          
Out[25]: (torch.Size([1, 3, 3, 1]), 4)
上一篇下一篇

猜你喜欢

热点阅读