Keras ImageDataGenerator 图像数据扩充参

2019-04-27  本文已影响0人  格物致知Lee

在利用图像数据进行深度学习建模的任务中,如果数据集较小,我们需要进行Image Data Augmentation:对已有图片进行平移,剪切,垂直对称等操作形成新的图片。将新图片加入数据集,从而扩充数据集。Keras的内置函数ImageDataGenerator就是用来扩充图像数据集的。下面我们对Keras中的ImageDataGenerator的各项参数进行说明和使用策略。

1.ImageDataGenerator 类参数说明

keras.preprocessing.image.ImageDataGenerator(

featurewise_center=False,   #将输入全部数据的均值设置为 0。一般不用。

samplewise_center=False,    #将每个样本的均值设置为 0。一般不用。

featurewise_std_normalization=False,#将输入除以全部数据标准差。一般不用。

samplewise_std_normalization=False,#将输入除以其标准差。一般不用。

zca_whitening=False,#是否应用 ZCA 白化。

zca_epsilon=1e-06, #ZCA 白化的 epsilon 值。常用。

rotation_range=0,#整数。随机旋转的度数范围。常用。

width_shift_range=0.0,#浮点数,水平平移百分比,不宜太大一般0.1,0.2

height_shift_range=0.0,#浮点数,垂直平移百分比,不宜太大一般0.1,0.2

brightness_range=None,#浮点数,亮度调整。

shear_range=0.0,#浮点数,错切变换角度。

zoom_range=0.0,#浮点数[0,1],随机缩放。[llow,upp]:随机缩放范围。

channel_shift_range=0.0,#浮点数[0.0,255.0],图像上色。

fill_mode='nearest',#边界填充,一般默认。

cval=0.0,#一般不用。

horizontal_flip=False,#水平翻转,常用。

vertical_flip=False,#垂直翻转,看应用场景使用。

rescale=None,#数据缩放,常用:1/255.0。

preprocessing_function=None,

data_format=None,

validation_split=0.0,#验证集划分。常用。

dtype=None

)

2.使用案例

a.使用.flow,传入列表数据进行数据扩充。


import keras

from keras.datasets import cifar10

from keras.preprocessing.image import ImageDataGenerator

#添加数据集

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

#标签向量化

y_train = np_utils.to_categorical(y_train, num_classes)

y_test = np_utils.to_categorical(y_test, num_classes)

#图片生成器

datagen = ImageDataGenerator(

    rotation_range=20,

    width_shift_range=0.2,

    height_shift_range=0.2,

    horizontal_flip=True

)

# fit

datagen.fit(x_train)

# flow

datagen.flow(x_train, y_train, batch_size=32)

b.通过.flow_from_directory(directory)加载文件中的图片并进行扩充。注意:目录下需要有各个类别图像对应的文件夹。如:train文件夹下有cats(里面只有猫的图片),dogs(里面值有狗的图片)文件夹。

image

train_datagen = ImageDataGenerator(

        rescale=1./255,

        shear_range=0.2,

        zoom_range=0.2,

        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

#不需要.fit()

train_generator = train_datagen.flow_from_directory(

        'data/train',

        target_size=(32, 32),

        batch_size=32)

validation_generator = test_datagen.flow_from_directory(

        'data/validation',

        target_size=(32, 32),

        batch_size=32)

小伙伴们如果觉得文章还行的请点个赞呦!!同时觉得文章哪里有问题的可以评论一下 谢谢你!

上一篇下一篇

猜你喜欢

热点阅读