机器学习与深度学习

pytorch 函数理解

2021-04-27  本文已影响0人  LCG22

1、torch.nn.Unfold

函数作用:

unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘

参数:

kernel_size: _size_any_t, 卷积核的大小

dilation: _size_any_t=1, 卷积核元素之间的空洞个数

padding: _size_any_t=0, 填充特征四周的列数,默认为 0,则不填充

stride: _size_any_t=1,卷积核移动的步长

函数理解:

参考资料:

PYTORCH实现手动滑窗,卷积(利用UNFOLD,FOLD操作)

unfold 过程:

① 对于 batch 里的每个数据分别进行 unfold

② 分别在每个数据的每个通道上,使用大小为 k*k 的卷积核进行从左往右,从上向下的滑窗

③ 对于在每个通道上分别得到的第一个滑窗区域,分别进行 reshape 成行向量,然后把在所有通道上得到的行向量,进行横向拼接,得到新的行向量

④ 对于在每个通道上得到的滑窗区域都进行步骤 ③ 的操作,直到所有的滑窗区域都处理完

⑤ 将步骤 ③ 和 步骤 ④ 中得到的行向量,进行纵向拼接,得到一个矩阵

⑥ 完成 unfold 操作,将 batch 中每个数据进行 unfold 得到的矩阵进行堆放,得到输出结果

例子:

x = torch.range(1, 2*3*4*5)

print(x.shape)

batch_x = x.reshape([2, 3, 4, 5])

print(batch_x.shape)

# unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘

unfold = torch.nn.Unfold(3)

res = unfold(batch_x)

print(res.shape)

结果:

torch.Size([2, 27, 6])

分析:

假设输入的 batch_x 维度为 [2, 3, 4, 5],其中 2 是批的数据量大小 B, 3 是通道数 C,4 是高度 H,5 是宽度 W 。使用的卷积核大小 K 为 3*3,移动步长 S 为 1,padding 为 0

① 在 B 的每个数据上进行 unfold

② 同时在每个通道上的最左上角开始进行滑动,对于每个通道,得到大小为 9 的滑动区域,然后进行 Reshape 成维度为 [1, 9] 的行向量。然后将在所有 3 个通道上得到的 3 个行向量,进行横向拼接,得到维度为 [1, 27] 的行向量。

③ 依次将卷积核按照从左到右,从上往下的顺序,按照步长 1 进行滑动,每个滑动的区域经过步骤 ② 中处理后都能得到一个维度为 [1, 27] 的行向量,共得到 6 个维度为 [1, 27],然后纵向堆叠成维度为 [6, 27] 的矩阵

④ 将每个数据经过 unfold 得到的维度为 [6, 27]  的矩阵进行堆叠成维度为 [2, 27, 6]  的张量

上一篇下一篇

猜你喜欢

热点阅读