keras 之 迁移学习: 改变VGG16输出层, 用ImageNet预训练权重训练自己的数据

2018-04-18  本文已影响0人  vola_lei

迁移学习: 用现成的网络跑自己的数据。保留已有网络除输出层以外其它层的权重, 把输出层的输出类别(class)个数改成自己数据的类别数, 然后以已有网络的权值为基础训练自己的网络。
以keras 2.1.5 / VGG16Net为例.

from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Sequential
from keras.layers import Dropout, Flatten, Dense
from keras import Model
from keras import initializers
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import VGG16
# --- Data preparation --------------------------------------------------------
# Keep VGG16's original input size, so every image is resized to 224x224 RGB.
input_shape = (224, 224, 3)
image_size = input_shape[:2]  # (height, width) passed to flow_from_directory

# Training images: pixel values are only rescaled to [0, 1].
# Augmentation options such as shear_range=0.2, zoom_range=0.2 or
# horizontal_flip=True can be added here to enlarge the training set.
train_datagen = ImageDataGenerator(rescale=1. / 255)
# Test images: rescaling only.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# One sub-directory per class under ./data/train/ and ./data/test/;
# labels are produced as one-hot vectors (class_mode='categorical').
train_generator = train_datagen.flow_from_directory(
    directory='./data/train/',
    target_size=image_size,
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    shuffle=True,
)
test_generator = test_datagen.flow_from_directory(
    directory='./data/test/',
    target_size=image_size,
    batch_size=10,
    class_mode='categorical',
)
# --- Build the network -------------------------------------------------------
# Build VGG16 with its complete architecture (include_top=True), but with a
# 10-class output layer instead of the original 1000-class ImageNet one.
# weights=None: the graph is built with randomly initialized weights; the
# pretrained ImageNet weights are loaded from a local file afterwards.
base_model = VGG16(input_shape=input_shape,
                   include_top=True,
                   classes=10,
                   weights=None)
print('Model loaded.')

# Rename the output layer (Keras' VGG16 calls it 'predictions') so that a
# later load_weights(..., by_name=True) skips it: its kernel now has 10
# output units, which does not match the 1000-unit kernel in the ImageNet file.
base_model.layers[-1].name = 'pred'

# Show which initializer the new output layer was built with. A bare
# expression only displays its value in an interactive session, so print it.
print(base_model.layers[-1].kernel_initializer.get_config())

将会得到 (即默认的 glorot_uniform 初始化器的配置):

{'distribution': 'uniform', 'mode': 'fan_avg', 'scale': 1.0, 'seed': None}
# --- Re-initialize the new output layer --------------------------------------
# Assigning kernel_initializer alone does not change any weights: the layer
# was already built (its kernel initialized with glorot_uniform, as printed
# above) when VGG16() constructed the model. Re-run the new initializer on
# the existing kernel so the 10-class output layer really starts from
# glorot_normal. The attribute is still set so the layer config matches.
from keras import backend as K

output_layer = base_model.layers[-1]
output_layer.kernel_initializer = initializers.glorot_normal()
K.set_value(
    output_layer.kernel,
    K.eval(output_layer.kernel_initializer(K.int_shape(output_layer.kernel))))

# Load the pretrained ImageNet weights into every layer whose name matches.
# The renamed output layer has no counterpart in the file, so it keeps its
# freshly initialized weights.
base_model.load_weights('./vgg16_weights_tf_dim_ordering_tf_kernels.h5', by_name=True)

# --- Compile -----------------------------------------------------------------
# SGD with Nesterov momentum: learning rate 0.01, decayed by 1e-4 per update.
sgd = optimizers.SGD(lr=0.01, decay=1e-4, momentum=0.9, nesterov=True)

base_model.compile(loss='categorical_crossentropy',
                   optimizer=sgd,
                   metrics=['accuracy'])

# --- Fine-tune the model -----------------------------------------------------
# Save the full model after every epoch. The path must name a file, not a
# directory ('./' cannot be opened as an HDF5 file). ModelCheckpoint fills in
# the epoch number and validation loss.
check = ModelCheckpoint('./checkpoint-{epoch:02d}-{val_loss:.2f}.h5',
                        monitor='val_loss',
                        verbose=0,
                        save_best_only=False,
                        save_weights_only=False,
                        mode='auto',
                        period=1)

# Stop as soon as val_loss fails to improve for one epoch (patience=0).
stop = EarlyStopping(monitor='val_loss',
                     min_delta=0,
                     patience=0,
                     verbose=0,
                     mode='auto')

base_model.fit_generator(
    generator=train_generator,
    # One pass over each directory per epoch. These are passed explicitly
    # rather than relying on Keras to infer them from the generators.
    steps_per_epoch=len(train_generator),
    epochs=5,
    verbose=1,
    validation_data=test_generator,
    validation_steps=len(test_generator),
    shuffle=True,
    callbacks=[check, stop])

# Save the fine-tuned weights of the model that was actually trained.
base_model.save_weights('fine_tuned_net.h5')
上一篇 下一篇

猜你喜欢

热点阅读