pytorch 函数理解
1、torch.nn.Unfold
函数作用:
unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘
参数:
kernel_size: _size_any_t, 卷积核的大小
dilation: _size_any_t=1, 卷积核元素之间的空洞个数
padding: _size_any_t=0, 填充特征四周的列数,默认为 0,则不填充
stride: _size_any_t=1,卷积核移动的步长
函数理解:
参考资料:
PYTORCH实现手动滑窗,卷积(利用UNFOLD,FOLD操作)
unfold 过程:
① 对于 batch 里的每个数据分别进行 unfold
② 分别在每个数据的每个通道上,使用大小为 k*k 的卷积核进行从左往右,从上向下的滑窗
③ 对于在每个通道上分别得到的第一个滑窗区域,分别进行 reshape 成行向量,然后把在所有通道上得到的行向量,进行横向拼接,得到新的行向量
④ 对于在每个通道上得到的滑窗区域都进行步骤 ③ 的操作,直到所有的滑窗区域都处理完
⑤ 将步骤 ③ 和 步骤 ④ 中得到的行向量,进行纵向拼接,得到一个矩阵
⑥ 完成 unfold 操作,将 batch 中每个数据进行 unfold 得到的矩阵进行堆放,得到输出结果
例子:
x = torch.range(1, 2*3*4*5)
print(x.shape)
batch_x = x.reshape([2, 3, 4, 5])
print(batch_x.shape)
# unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘
unfold = torch.nn.Unfold(3)
res = unfold(batch_x)
print(res.shape)
结果:
torch.Size([2, 27, 6])
分析:
假设输入的 batch_x 维度为 [2, 3, 4, 5],其中 2 是批的数据量大小 B, 3 是通道数 C,4 是高度 H,5 是宽度 W 。使用的卷积核大小 K 为 3*3,移动步长 S 为 1,padding 为 0
① 在 B 的每个数据上进行 unfold
② 同时在每个通道上的最左上角开始进行滑动,对于每个通道,得到大小为 9 的滑动区域,然后进行 Reshape 成维度为 [1, 9] 的行向量。然后将在所有 3 个通道上得到的 3 个行向量,进行横向拼接,得到维度为 [1, 27] 的行向量。
③ 依次将卷积核按照从左到右,从上往下的顺序,按照步长 1 进行滑动,每个滑动的区域经过步骤 ② 中处理后都能得到一个维度为 [1, 27] 的行向量,共得到 6 个维度为 [1, 27],然后纵向堆叠成维度为 [6, 27] 的矩阵
④ 将每个数据经过 unfold 得到的维度为 [6, 27] 的矩阵进行堆叠成维度为 [2, 27, 6] 的张量