Pytorch里面的一些小细节

2020-04-10  本文已影响0人  吱吱加油

[ ]表示列表数据

{}表示字典数据

numpy.arange(n).reshape((z,r,l))

其中arange中的n表示取0~n的数,z表示第三个维度通道数,r表示行,l表示列

numpy.transpose(a,b,c)

其中行表示0,列表示1,第三维度即通道数表示2

即若要让其不变即

transpose(0,1,2)


import numpy as np

x=np.arange(20).reshape((2,2,5))

y=x.transpose(0,1,2)

print('X equals:\n',x)

print('Y equals:\n',y)


X equals:

[[[ 0  1  2  3  4]

  [ 5  6  7  8  9]]

[[10 11 12 13 14]

  [15 16 17 18 19]]]

Y equals:

[[[ 0  1  2  3  4]

  [ 5  6  7  8  9]]

[[10 11 12 13 14]

  [15 16 17 18 19]]]


y=x.transpose(2,0,1)

Y equals:

[[[ 0  5]

  [10 15]]

[[ 1  6]

  [11 16]]

[[ 2  7]

  [12 17]]

[[ 3  8]

  [13 18]]

[[ 4  9]

  [14 19]]]


详细解释可以看看这个numpy中的transpose函数使用方法 - 学弟1 - 博客园


torch.squeeze:移除长度为1的维度,也可以指定移除某一维度,但是该维度只有长度为1时,才能够被移除


import torch

t = torch.rand((2,5,3,1))

t_0 = torch.squeeze(t)

t_1 = torch.squeeze(t,0)

t_2 = torch.squeeze(t,1)

t_3 = torch.squeeze(t,2)

t_4 = torch.squeeze(t,3)

print(t.shape)

print(t_0.shape)

print(t_1.shape)

print(t_2.shape)

print(t_3.shape)

print(t_4.shape)


torch.Size([2, 5, 3, 1])

torch.Size([2, 5, 3])

torch.Size([2, 5, 3, 1])

torch.Size([2, 5, 3, 1])

torch.Size([2, 5, 3, 1])

torch.Size([2, 5, 3])


torch.unsqueeze():用于数据维度的增加


import torch

a = torch.randn(1,3)

print(a)

print(a.shape)

b= torch.unsqueeze(a,1)

print(b)

print(b.shape)

c=a.unsqueeze(0)

print(c)

print(c.shape)


result:

tensor([[-0.4639, -0.7896, 1.1053]])

torch.Size([1, 3])

tensor([[[-0.4639, -0.7896,  1.1053]]])

torch.Size([1, 1, 3])

tensor([[[-0.4639, -0.7896,  1.1053]]])

torch.Size([1, 1, 3])


上一篇 下一篇

猜你喜欢

热点阅读