机器学习

神经网络参数量计算

2020-09-06  本文已影响0人  ce0b74704937

这篇文章主要是记录一下神经网络的参数量大致估算方法,计算过程利用的是mobilenet-v2来举例说明,如果对mobilenet-v2不太了解,可以参考文章mobilenet-v1和mobilenet-v2详解

关于卷积参数计算的方式

假设输入feature(如果是第一层就是图片了)的维度为N_{input} \times C_{input} \times H_{input} \times W_{input}
经过卷积后输出的feature的维度为N_{output} \times C_{output} \times H_{output} \times W_{output}

普通卷积

对于普通卷积来说,假设卷积的kernel大小为K \times K,该层卷积的参数量为C_{input} \times C_{output} \times K \times K

注意:因为现有的很多网络,卷积后都会跟上BN操作,所以一般卷积都不会加bias,所以这里都假设无bias的情况,mobilenet-v2里面的卷积是没有的

depthwise卷积

对于mobilenet来说是一个比较轻量级的网络,主要原因是利用了depthwise conv操作。对于depthwise conv来说,经过该类卷积后,会有C_{output} = C_{input},也就是卷积的通道数和输入通道数是一样的。这时,该卷积的参数量为C_{input} \times K \times K

注意:拿TensorFlow来说,里面有个接口tf.nn.depthwise_conv2d(
input, filter, strides, padding, rate=None, name=None, data_format=None,
dilations=None
)
,这个接口的输入filter参数包含一个叫channel_multiplier,这个参数可以使得C_{output} = channel\_multiplier * C_{input}。上述说的是原始的depthwise(C_{output} = C_{input}),也是mobilenet里面使用的

moblienet-v2结构

mobilenet-v2 blockneck

每一层参数计算

这里的一层对应上表每一行

输入尺寸 参数量 详情
conv2d 224,224,3 864 3 * 3 * 3 * 32
blockneck 112,112,32 1824 blockneck(1,32,16)
blockneck 112,112,16 12912 blockneck(6,16,24) + bottleneck(6,24,24)
blockneck 56,56,24 37392 blockneck(6,24,32) + bottleneck(6,32,32) * 2
blockneck 28,28,32 177984 blockneck(6,32,64) + bottleneck(6,64,64)* 3
blockneck 14,14,64 281796 blockneck(6,53,96) + bottleneck(6,96,96)* 2
blockneck 14,14,96 784320 blockneck(6,96,160) + bottleneck(6,160,160)* 2
blockneck 7,7,160 469440 blockneck(6,160,320) * 1
conv2d 7,7,320 409600 1 * 1 * 320 * 1280
conv2d 1,1,1280 1280000 1280 * 1000
总量 \ 3456132 上述所有行相加

其中表格里的blockneck是一个函数,实现如下:

def blockneck(t, input, output):
    """
    args:
        t: 表示blockneck中的扩展因子(expansion factor)
        input:表示输入feature的通道数
        output: 表示输出feature的通道数
    return:
        返回一个blockneck的参数量
    """
    conv_1 = 1 * 1 * (input**2) * t
    dwise = 3 * 3 * input * t
    conv_2 = 1 * 1 * (input * t) * output
    return conv_1 + dwise + conv_2

还要补充说明的是

  1. 上述计算没有加入bias项、bn等
  2. 上述结果是参数量,并不是参数存储所需空间,拿float来存储参数来说,还需要乘以4(float占用4字节)。
上一篇下一篇

猜你喜欢

热点阅读