The Xception Algorithm

2018-12-03  LuDon

Introduction

Xception is Google's follow-up to Inception: a further improvement on Inception V3 that mainly replaces the standard convolutions in Inception V3 with depthwise separable convolutions.

Questions to keep in mind

What problem does it aim to solve, and how?

How well does it work?

What problems remain?

Background

Xception is a member of the Inception family.
Compared with an ordinary convolution, an Inception module has stronger representational power.

A quick review of the Inception series

The Inception module is shown in Figure 1: several parallel convolution branches are computed and their outputs merged, replacing a single convolution.

Figure 1. The Inception module

Adding 1×1 convolutions to reduce dimensionality gives the structure in Figure 2.

Figure 2. The Inception module after dimensionality reduction
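As a rough Keras sketch of the module in Figure 2 (the filter counts here are illustrative, not taken from the paper):

from keras.layers import Input, Conv2D, MaxPooling2D, concatenate
from keras.models import Model

inp = Input((28, 28, 192))
# 1x1 branch
b1 = Conv2D(64, (1, 1), padding='same', activation='relu')(inp)
# 1x1 reduction followed by 3x3
b2 = Conv2D(96, (1, 1), padding='same', activation='relu')(inp)
b2 = Conv2D(128, (3, 3), padding='same', activation='relu')(b2)
# 1x1 reduction followed by 5x5
b3 = Conv2D(16, (1, 1), padding='same', activation='relu')(inp)
b3 = Conv2D(32, (5, 5), padding='same', activation='relu')(b3)
# pooling branch with a 1x1 projection
b4 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(inp)
b4 = Conv2D(32, (1, 1), padding='same', activation='relu')(b4)
out = concatenate([b1, b2, b3, b4])
module = Model(inp, out)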

Two stacked 3×3 convolutions can replace one 5×5 convolution (same receptive field, fewer parameters), so the structure becomes the one in Figure 3.

Figure 3
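A quick parameter count shows why the substitution pays off; a minimal sketch, with an illustrative channel count of 64:

# replacing one 5x5 conv with two stacked 3x3 convs, C input/output channels
C = 64
params_5x5 = 5 * 5 * C * C          # 102400
params_3x3 = 2 * (3 * 3 * C * C)    # 73728, same 5x5 receptive field
print(1 - params_3x3 / params_5x5)  # -> 0.28, i.e. ~28% fewer parameters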
The modules above appear mainly in Inception V3, whose basic structure is:
input
conv2d(32, 3, 3, s=2)     # conv2d_1a
conv2d(32, 3, 3)          # conv2d_2a
conv2d(64, 3, 3, 'SAME')  # conv2d_2b
max_pool2d(3, 3, s=2)     # maxpool_3a
conv2d(80, 1, 1)          # conv2d_3b
conv2d(192, 3, 3)         # conv2d_4a
max_pool2d(3, 3, s=2)     # maxpool_5a

# four parallel branches (filters, k, k), outputs concatenated:
conv2d(64, 1, 1)    conv2d(48, 1, 1)    conv2d(64, 1, 1)    avgpool(3, 3)
                    conv2d(64, 5, 5)    conv2d(96, 3, 3)    conv2d(32, 1, 1)
                                        conv2d(96, 3, 3)

concat
(the inception module above is repeated, ×9)
conv2d(num_class, 1, 1)

In each of the conv layers above, what is learned is a 3D kernel spanning two spatial dimensions and one channel dimension, i.e. (c, h, w). This kernel is convolved with the input across all three dimensions at once to produce the output. In pseudocode:

# for the i-th filter: the convolution result centred at input position (x, y)
s = 0
for c in range(C):
    for h in range(K):
        for w in range(K):
            s += input[c, y - K//2 + h, x - K//2 + w] * filter_i[c, h, w]
out[i, y, x] = s

As this shows, in a 3D convolution the channel dimension is treated in exactly the same way as the two spatial dimensions.
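To make this concrete, here is a minimal runnable version of the loop above for a single filter with 'valid' padding (array names and sizes are illustrative):

import numpy as np

C, H, W, K = 3, 8, 8, 3
inp = np.random.rand(C, H, W)
filt = np.random.rand(C, K, K)   # one 3D kernel: channels x height x width

out = np.zeros((H - K + 1, W - K + 1))
for y in range(out.shape[0]):
    for x in range(out.shape[1]):
        # the channel dimension is reduced together with the spatial window
        out[y, x] = np.sum(inp[:, y:y+K, x:x+K] * filt)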

Now first apply a single, shared 1×1 convolution and attach three 3×3 convolutions after it, as in Figure 4. Each of the three 3×3 convolutions takes only part of the 1×1 output as its input; in the figure, each one sees 1/3 of the channels (sketched in code below).

Figure 4
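A minimal Keras sketch of this partitioned version (the channel counts 96 and 32 are illustrative):

from keras.layers import Input, Conv2D, Lambda, concatenate
from keras.models import Model

inp = Input((28, 28, 64))
x = Conv2D(96, (1, 1), use_bias=False)(inp)   # one shared 1x1 convolution
# each 3x3 convolution sees only one third of the 1x1 output channels
groups = [Lambda(lambda t, i=i: t[:, :, :, 32 * i:32 * (i + 1)])(x) for i in range(3)]
out = concatenate([Conv2D(32, (3, 3), padding='same')(g) for g in groups])
model = Model(inp, out)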

Next, push this to the extreme: let the number of 3×3 convolutions equal the number of output channels of the 1×1 convolution, so that each 3×3 kernel convolves with exactly one input channel, as in Figure 5.

Figure 5
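This extreme case amounts to a per-channel spatial convolution. A sketch in Keras, assuming a version where DepthwiseConv2D is available (>= 2.1.5); shapes are illustrative:

from keras.layers import Input, Conv2D, DepthwiseConv2D
from keras.models import Model

inp = Input((28, 28, 64))
x = Conv2D(128, (1, 1), use_bias=False)(inp)                    # cross-channel (pointwise) mapping
x = DepthwiseConv2D((3, 3), padding='same', use_bias=False)(x)  # one 3x3 filter per channel
model = Model(inp, x)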

Xception

Xception is built around the depthwise separable convolution, which splits a standard convolution into two steps: a depthwise convolution that applies one spatial filter per input channel, followed by a pointwise 1×1 convolution that mixes the channels:

Figure 6

A depthwise separable convolution differs from the structure above in two ways: the order of the two operations is reversed (the per-channel spatial convolution comes first, then the 1×1 convolution), and there is no non-linearity between the two steps, whereas in Inception both operations are followed by a ReLU.
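The factorisation also cuts the parameter count sharply; a worked example with illustrative channel counts:

# standard conv: K*K*Cin*Cout parameters; depthwise separable: K*K*Cin + Cin*Cout
K_, Cin, Cout = 3, 128, 128
standard = K_ * K_ * Cin * Cout          # 147456
separable = K_ * K_ * Cin + Cin * Cout   # 1152 + 16384 = 17536
print(standard / separable)              # -> ~8.4x fewer parameters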

The Xception architecture takes a ResNet-style network and replaces its convolutions with depthwise separable convolutions, as shown below; SeparableConv denotes the depthwise separable convolution module. In addition, the concatenations of Inception are replaced by residual connections.


The Xception architecture

References

[1] François Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. CVPR 2017.

Code walkthrough

### Xception.py
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras import backend as K
from keras.layers import Dense
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs  # keras.utils.get_source_inputs in newer Keras
import tensorflow as tf

# Configuration flags -- this script is an excerpt of keras.applications.xception,
# where these are function arguments; defaults are chosen here so the excerpt runs.
include_top = True
classes = 1000
pooling = None
weights = None          # set to 'imagenet' to download pretrained weights
old_data_format = None  # in the original source, set when the backend data format was switched
# TF_WEIGHTS_PATH / TF_WEIGHTS_PATH_NO_TOP: weight-file URLs, defined in the original source

input_tensor = tf.ones([1, 224, 224, 3])
input_shape = [224, 224, 3]
img_input = Input(tensor=input_tensor, shape=input_shape)

x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input) #(1, 111, 111, 32)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)  #(1, 109, 109, 64)
x = BatchNormalization(name='block1_conv2_bn')(x)
x = Activation('relu', name='block1_conv2_act')(x)
 
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual) #(1, 55, 55, 128)
 
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
x = layers.add([x, residual])
residual = Conv2D(256, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
 
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
x = layers.add([x, residual])
 
residual = Conv2D(728, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
x = layers.add([x, residual])
 
for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)
 
    x = Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
    x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
    x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
    x = Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
    x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
 
    x = layers.add([x, residual])
 
residual = Conv2D(1024, (1, 1), strides=(2, 2),
                      padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
 
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
 
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
x = layers.add([x, residual])
 
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
 
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
 
if include_top:
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)
else:
    if pooling == 'avg':
        x = GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = GlobalMaxPooling2D()(x)
 
if input_tensor is not None:
    inputs = get_source_inputs(input_tensor)
else:
    inputs = img_input
 
model = Model(inputs, x, name='xception')
 
if weights == 'imagenet':
    if include_top:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels.h5',
                                TF_WEIGHTS_PATH,
                                cache_subdir='models')
    else:
        weights_path = get_file('xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                TF_WEIGHTS_PATH_NO_TOP,
                                cache_subdir='models')
    model.load_weights(weights_path)
if old_data_format:
    K.set_image_data_format(old_data_format)
# (in the original keras source this is the end of a function: `return model`)
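As a quick sanity check of the excerpt, the parameter count can be printed; the reference number is from the Keras documentation:

model.summary()
print(model.count_params())   # ~22.9M parameters with the classification top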

[1] Code reference
