自然语言处理知识图谱

一文搞懂池化层!Pooling详解(魔改篇)

2022-07-12  本文已影响0人  晓柒NLP与药物设计

一. Overlapping Pooling(重叠池化)

重叠池化正如其名字所说的,相邻池化窗口之间会有重叠区域,此时sizeX > stride

提出于ImageNet Classification with Deep Convolutional Neural Networks

二. 空金字塔池化(Spatial Pyramid Pooling)

空间金字塔池化可以把任何尺度的图像的卷积特征转化成相同维度,这不仅可以让CNN处理任意尺度的图像,还能避免cropping和warping操作,导致一些信息的丢失,具有非常重要的意义

空间金字塔池化:先让图像进行卷积操作,然后转化成维度相同的特征输入到全连接层,把CNN扩展到任意大小的图像

1. SPP显著特点
  1. 不管输入尺寸是怎样,SPP可以产生固定大小的输出
  2. 使用多个窗口(pooling window)
  3. SPP可以使用同一图像不同尺寸(scale)作为输入, 得到同样长度的池化特征
2. SPP优势
  1. 由于对输入图像的不同纵横比和不同尺寸,SPP同样可以处理,所以提高了图像的尺度不变(scale-invariance)和降低了过拟合(over-fitting)
  2. 实验表明训练图像尺寸的多样性比单一尺寸的训练图像更容易使得网络收敛(convergence)
  3. SPP对于特定的CNN网络设计和结构是独立的
  4. 不仅可以用于图像分类而且可以用来目标检测
3. 步骤

输入一张任意尺寸的图片的时候,利用不同网格大小的池化核,对一张图片进行了划分:

  1. 第1个SPP池化核=(4*4),把一张完整的任意长度的输入图片,分成16个块,也就是每个块的大小就是(w/4,h/4)
  2. 第2个SPP池化核=(2*2),把一张完整的任意长度的输入图片,分成4个块,也就是每个块的大小就是(w/2,h/2)
  3. 第3个SPP池化核=(1*1),把一张完整的任意长度的输入图片,分成1个块,也就是每个块的大小就是(w,h)

经过上述的切分,最后共得到21个块,即一个固定长度的特殊输出

最大池化特征提取

对每个网格,按照各自占用输入特征的范围,各自独立的进行池化,支持的算法算法有:

输出21个输出特征值组成了最终的输出。这样就可以把任意长度的特征值,转化为一个固定大小的特征值了,由于SPP采用了三层不同的池化核,因此转化后的特征值,既包含了高层抽象特征,也包含了低层的具体特征

import torch
import torch.nn.functional as F
# 构建SPP层(空间金字塔池化层)
class SPPLayer(torch.nn.Module):
    def __init__(self, num_levels, pool_type='max_pool'):
        super(SPPLayer, self).__init__()
        self.num_levels = num_levels
        self.pool_type = pool_type
    def forward(self, x):
        num, c, h, w = x.size() # num:样本数量 batch_size c:通道数 h:高 w:宽
        for i in range(self.num_levels):
            level = i+1
            kernel_size = (math.ceil(h / level), math.ceil(w / level))
            stride = (math.ceil(h / level), math.ceil(w / level))
            pooling = (math.floor((kernel_size[0]*level-h+1)/2), math.floor((kernel_size[1]*level-w+1)/2))
            # 选择池化方式
            if self.pool_type == 'max_pool':
                tensor = F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
            else:
                tensor = F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
            # 展开、拼接
            if (i == 0):
                x_flatten = tensor.view(num, -1)
            else:
                x_flatten = torch.cat((x_flatten, tensor.view(num, -1)), 1)
        return x_flatten
上一篇下一篇

猜你喜欢

热点阅读