图像处理深度学习

如何用小样本训练高性能深度网络

2017-10-19  本文已影响0人  苟且偷生小屁屁

本文借鉴http://blog.csdn.net/caanyee/article/details/52502759,自学使用.

数据预处理与数据提升

为了尽量利用我们有限的训练数据,我们将通过一系列随机变换对数据进行提升,这样我们的模型将看不到任何两张完全相同的图片,这有利于我们抑制过拟合,使得模型的泛化能力更好。

在Keras中,这个步骤可以通过keras.preprocessing.image.ImageGenerator来实现

keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-6,
    rotation_range=0.,
    width_shift_range=0.,
    height_shift_range=0.,
    shear_range=0.,
    zoom_range=0.,
    channel_shift_range=0.,
    fill_mode='nearest',
    cval=0.,
    horizontal_flip=False,
    vertical_flip=False,
    rescale=None,
    preprocessing_function=None,
    data_format=K.image_data_format())

用以生成一个batch的图像数据,支持实时数据提升。训练时该函数会无限生成数据,直到达到规定的epoch次数为止。

参数

featurewise_center:布尔值,使输入数据集去中心化(均值为0), 按feature执行,默认false

samplewise_center:布尔值,使输入数据的每个样本均值为0,默认false

featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行

samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差

zca_whitening:布尔值,对输入数据施加ZCA白化

zca_epsilon: ZCA使用的eposilon,默认1e-6

rotation_range:整数,数据提升时图片随机转动的角度

width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度

height_shift_range:浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度

shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)

zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]

channel_shift_range:浮点数,随机通道偏移的幅度

fill_mode:;‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理

cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值

horizontal_flip:布尔值,进行随机水平翻转

vertical_flip:布尔值,进行随机竖直翻转

rescale: 重放缩因子,默认为None. 如果为None或0则不进行放缩,否则会将该数值乘到数据上(在应用其他变换之前)

preprocessing_function: 将被应用于每个输入的函数。该函数将在任何其他修改之前运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array

data_format:字符串,“channel_first”或“channel_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channel_last”对应原本的“tf”,“channel_first”对应原本的“th”。以128x128的RGB图像为例,“channel_first”应将数据组织为(3,128,128),而“channel_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channel_last”

举个例子:

datagen = ImageDataGenerator(
        rotation_range=40, (随机旋转40度)
        width_shift_range=0.2,(图片水平偏移20%)
        height_shift_range=0.2,(图片垂直偏移20%)
        rescale=1./255,(归一化,缩放至1/255,浮点数,数值乘以数据)
        shear_range=0.2,(逆时针方向剪切变换角度)
        zoom_range=0.2,(随机缩放,[lower,upper],浮点数,[lower,upper]=[1-浮点数,1+浮点数])
        horizontal_flip=True,(随机水平翻转)
        fill_mode='nearest'(超出边界时怎么处理))

附带一个错切的程序

import cv
import math

def Warp(image, angle):
    a = math.tan(angle * math.pi / 180.0)
    W = image.width
    H = int(image.height + W * a)
    size = (W, H)
    iWarp = cv.CreateImage(size, image.depth, image.nChannels)
    for i in range(image.height):
        for j in range(image.width):
            x = int(i + j * a)
            iWarp[x, j] = image[i, j]
    return iWarp

image = cv.LoadImage('data/train/cat.7.jpg', 1)
iWarp1 = Warp(image, 15)
cv.ShowImage('image', image)
cv.ShowImage('1', iWarp1)
cv.WaitKey(0)

数据提升是对抗过拟合问题的一个武器,但还不够,因为提升过的数据仍然是高度相关的。对抗过拟合的你应该主要关注的是模型的“熵容量”——模型允许存储的信息量。能够存储更多信息的模型能够利用更多的特征取得更好的性能,但也有存储不相关特征的风险。另一方面,只能存储少量信息的模型会将存储的特征主要集中在真正相关的特征上,并有更好的泛化性能。

有很多不同的方法来调整模型的“熵容量”,常见的一种选择是调整模型的参数数目,即模型的层数和每层的规模。另一种方法是对权重进行正则化约束,如L1或L2.这种约束会使模型的权重偏向较小的值。

在我们的模型里,我们使用了很小的卷积网络,只有很少的几层,每层的滤波器数目也不多。再加上数据提升和Dropout,就差不多了。Dropout通过防止一层看到两次完全一样的模式来防止过拟合,相当于也是一种数据提升的方法。(你可以说dropout和数据提升都在随机扰乱数据的相关性)

我们再来回顾一下数据提升的用法, ImageDataGenerator

from keras.preprocessing.image import ImageDataGenerator

然后对训练集,验证集进行数据提升

train_datagen = ImageDataGenerator(
rescale = 1./255, (归一化,这个不可少)
shear_range = 0.2, (错切,正的话就是逆向)
zoom_range = 0.2, (随机缩放,[1-0.2,1+0.2])
horizontal_flip = True (横向的翻转))
validation_datagen = ImageDataGenerator(
rescale = 1./255(归一化,因为是验证集,没有必要进行其他的数据提升处理))
train_generator = train_datagen.flow_from_directory(
**directory=train_data_dir(这个参数最重要,指定提升数据的来源)**,
target_size = 整数tuple,默认为(256, 256). 图像将被resize成该尺寸,
color_mode = 颜色模式,为"grayscale","rgb"之一,默认为"rgb",
batch_size =batch数据的大小,默认32,
shuffle =  是否打乱数据,默认为True,
class_mode="categorical", "binary", "sparse"或None之一. 默认为"categorical. 
该参数决定了返回的标签数组的形式, 
"categorical"会返回2D的one-hot编码标签,
"binary"返回1D的二值标签.
"sparse"返回1D的整数标签,
如果为None则不返回任何标签, 
生成器将仅仅生成batch数据, 这种情况在使用model.predict_generator()和
model.evaluate_generator()等函数时会用到.
)

class_mode是很重要的,返回标签,如果是categorical,那么就是one-hot型,如果是binary,那么就是0,1,如果是sparse那么就是1或者8这样的.

validation_generator = validation_datagen.flow_from_direction(
direction = validation_data_dir,
target_size = (img_width,img_height),
batch_size=batch_size,
color_mode = 'rgb',
class_mode = 'binary')

再次强调一下class_mode的重要性,因为ImageDataGenerator这种图像的提升是对所有类别的图像进行的(包括在directory中),而送入train_generator中以及fit_generator中时是不分训练集和标签集的,所以保存图像的时候要按类别分别保存在不同的文件夹中,文件名可以无所谓.有几个文件夹就可以分几类,像猫狗大战这个问题,由于是2分类问题,所以class_mode可以是binary,而在二分类时,model.compile可以选用的loss可以是'binary_crossentropy',还可以是别的吗,loss还要再想一想.

然后就可以训练模型了

model.fit_generator(
generator = train_generator,
step_per_epoch = nb_train_samples//batch_size(每个epoch有多少步),
epochs=EPOCHS,
verbose=1,
validation_dataz = validation_generator,
validation_steps = nb_validation_samples//batch_size,
callbacks = [callbacks](Tensorboard.ModelCheckpoint...))
  • 最后进行一个总结:
    使用数据提升需要用到
keras.preprocessing.image.ImageDataGenerator

所以第一步:

from keras.preprocessing.image import ImageDataGenerator

第二步,写出提升训练集的类

train_datagen = ImageDataGenerator(
scale = 1./255(这个一般要有)
shear_range = 0.2,(这个是错切,其实是仿射变化)
zoom_range = 0.2,
horizontal_flip = True)

第三步,提升训练集类的实体化

traindata_generator = train_datagen.flow_from_direction
(direction = train_data_dir,
color_mode = 'rgb',
target_size=(img_width,img_height),
class_mode = 'binary')

第四步,训练

model.fit_generator
(generator = traindata_generator,
step_per_epoch = nb_train_samples//batch_size,
epochs = EPOCHS,
verbose = 1,
validation_data =validation_generator,
validation_steps = nb_validation_samples//batch_size,
callbacks=[callback] )

最后,附上几张数据提升猫猫的图像

图片.png

最后给出训练结果,最好的val_acc已经达到0.9399

Firefox_Screenshot_2017-10-19T01-24-58.288Z.png Firefox_Screenshot_2017-10-19T01-24-40.874Z.png Firefox_Screenshot_2017-10-19T01-24-26.805Z.png Firefox_Screenshot_2017-10-19T01-24-10.142Z.png
上一篇下一篇

猜你喜欢

热点阅读