深度学习AI人工智能与数学之美

pytorch 切片(下)

2020-08-18  本文已影响0人  zidea
slice.jpeg

使用 index_select 切分数据

下面介绍一个index_selectAPI 对 tensor 数据进行切分。

a.index_select(2,torch.arange(28)).shape
torch.Size([4, 3, 28, 28])

因为在 2 维度上选取0 -27 维度范围也就是没有进行任何切分效果因为数据在此维度上就是 28 维度

a.index_select(0,torch.tensor([0,2])).shape
torch.Size([2, 3, 28, 28])

习题

呵呵这里弄一个习题,给大家习题还是第一次,大家可以自己看看解释一下下面代码输出为什么是torch.Size([4, 3, 8, 28])

a.index_select(2,torch.arange(8)).shape
torch.Size([4, 3, 8, 28])

省略表示默认

a[...].shape

在计算机中编程一些符号有着其在编程中独特含义,不同语言也可能不同。但是合理的 API 设计是不用大家学习,一看就知道怎么用 API,这里省略号表示默认,所以这里省略号就是表示什么也不做,对数据没有切分操作。

torch.Size([4, 3, 28, 28])

我们看第一个,下面切分在第一位是 0 表示确定是第一张图片,后面省略好表示不做任何操作,所以切分出的数据是表示一张 28 \times 28 大小 3 通道的图片。

a[0,...].shape
torch.Size([3, 28, 28])

如果在第 2 维度上取 1 表示,表示数据 2 维度确定都是取图片 1 通道,从图片通道 RGB 来看,这里描述就是我们得到 3 张图片都是取 G 通道数据

a[:,1,...].shape
torch.Size([4, 28, 28])

接下来看这段代码

a[0,...,::2].shape

通过遮罩来筛选数据

x = torch.randn(3,4)

通过 randn 随机生产 3 \times 4 矩阵,然后通过条件进行筛选得到 mask 矩阵。在 mask 矩阵会根据条件生产一个矩阵,矩阵是由 True 和 False 来表示。

tensor([[-0.0040,  0.3439,  1.3629,  1.7692],
        [-0.1891, -2.1325, -0.9377, -0.3534],
        [-0.4318,  0.3152,  0.1341, -1.5351]])
mask = x.ge(0.5)
tensor([[False,  True,  True, False],
        [ True,  True, False, False],
        [False,  True,  True,  True]])

然后使用 masked_select 来利用 mask 进行筛选得到数据切分效果

torch.masked_select(x,mask)
tensor([1.3629, 1.7692])
上一篇 下一篇

猜你喜欢

热点阅读