深度学习TF深度学习

TensorFlow 使用预训练模型 ResNet-50

2018-04-29  本文已影响2325人  公输睚信

        前面的文章已经说明了怎么使用 TensorFlow 来构建、训练、保存、导出模型等,现在来说明怎么使用 TensorFlow 调用预训练模型来精调神经网络。为了简单起见,以调用预训练的 ResNet-50 用于图像分类为例,使用的模块仍然是 tf.contrib.slim

        TensorFlow 的所有用于图像分类的预训练模型的下载地址为 models/research/slim,包含常用的 VGG,Inception,ResNet,MobileNet 以及最新的 NasNet 模型等。要使用这些预训练模型的关键是将这些预训练的参数正确的导入到定义好的神经网络,这可以通过函数 slim.assign_from_checkpoint_fn 来方便的实现。下面,用代码来说明。

一、Fine tuning 模型定义

        前已提及,TensorFlow 所有预训练模型均在 GitHub 项目 models/research/slim,而其对应的神经网络实现则在其子文件夹 nets。我们以调用 ResNet-50 为例(其它模型类似),首先来定义网络结构:

import tensorflow as tf

from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


def predict(self, preprocessed_inputs):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

    Args:
        preprocessed_inputs: A float32 tensor with shape [batch_size,
            height, width, num_channels] representing a batch of images.
            
    Returns:
        prediction_dict: A dictionary holding prediction tensors to be
            passed to the Loss or Postprocess functions.
    """
    net, endpoints = nets.resnet_v1.resnet_v1_50(
        preprocessed_inputs, num_classes=None,
        is_training=self._is_training)
    net = tf.squeeze(net, axis=[1, 2])
    net = slim.fully_connected(net, num_outputs=self.num_classes,
                               activation_fn=None, scope='Predict')
    prediction_dict = {'logits': net}
    return prediction_dict

        我们假设要分类的图像有 self.num_classes 个类,随机选择一个批量的图像,对这些图像进行预处理后,把它们作为参数传入 predict 函数,此时直接调用 TensorFlow-slim 封装好的 nets.resnet_v1.resnet_v1_50 神经网络得到图像特征,因为 ResNet-50 是用于 1000 个类的分类的,所以需要设置参数 num_classes=None 禁用它的最后一个输出层。我们假设输入的图像批量形状为 [None, 224, 224, 3],则 resnet_v1_50 函数返回的形状为 [None, 1, 1, 2048],为了输入到全连接层,需要用函数 tf.squeeze 去掉形状为 1 的第 1,2 个索引维度。最后,连接再一个全连接层得到 self.num_classes 个类的预测输出。

        可以看到,使用 tf.contrib.slim 模块,调用 ResNet-50 等神经网络变得异常简单。而接下来的关键问题是怎么导入预训练的参数,进而使用我们自己的数据来对预训练模型进行精调。在阐述怎么解决这个问题之前,先将整个模型定义的文件 model.py 列出以方便阅读:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from abc import ABCMeta
from abc import abstractmethod
from tensorflow.contrib.slim import nets

slim = tf.contrib.slim


class BaseModel(object):
    """Abstract base class for any model."""
    __metaclass__ = ABCMeta
    
    def __init__(self, num_classes):
        """Constructor.
        
        Args:
            num_classes: Number of classes.
        """
        self._num_classes = num_classes
        
    @property
    def num_classes(self):
        return self._num_classes
    
    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be override by implementations.
        
        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.
            
        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size, 
                height, widht, num_channels] representing a batch of images.
        """
        pass
    
    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass
    
    @abstractmethod
    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        pass
    
    @abstractmethod
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        pass
    
        
class Model(BaseModel):
    """xxx definition."""
    
    def __init__(self, is_training, num_classes):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        super(Model, self).__init__(num_classes=num_classes)
        
        self._is_training = is_training
        
    def preprocess(self, inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        channel_means = [123.68, 116.779, 103.939]
        preprocessed_inputs = tf.to_float(inputs)
        preprocessed_inputs = preprocessed_inputs - [[channel_means]]
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        net, endpoints = nets.resnet_v1.resnet_v1_50(
            preprocessed_inputs, num_classes=None,
            is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None, scope='Predict')
        prediction_dict = {'logits': logits}
        return prediction_dict
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        logits = prediction_dict['logtis']
        logits = tf.nn.softmax(logits)
        classes = tf.argmax(logits, axis=1)
        postprocessed_dict = {'classes': classes}
        return postprocessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists_dict: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict['logtis']
        slim.losses.sparse_softmax_cross_entropy(
            logits=logits, 
            labels=groundtruth_lists,
            scope='Loss')
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
        
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy

二、预训练模型导入

        要将预训练模型 ResNet-50 的参数导入到前面定义好的模型,需要继续借助 tf.contrib.slim 模块,而且方法很简单,只需要在训练函数 slim.learning.train 中指定初始化参数来源函数 init_fn 即可,而这可以通过函数

slim.assign_from_checkpoint_fn(model_path, var_list,
                               ignore_missing_vars=False,
                               reshape_variables=False)

很方便的实现。其中,第一个参数 model_path 指定预训练模型 xxx.ckpt 文件的路径,第二个参数 var_list 指定需要导入对应预训练参数的所有变量,通过函数

slim.get_variables_to_restore(include=None,
                              exclude=None)

可以快速指定,如果需要排除一些变量,也就是如果想让某些变量随机初始化而不是直接使用预训练模型来初始化,则直接在参数 exclude 中指定即可。第三个参数 ignore_missing_vars 非常重要,一定要将其设置为 True,也就是说,一定要忽略那些在定义的模型结构中可能存在的而在预训练模型中没有的变量,因为如果自己定义的模型结构中存在一个参数,而这些参数在预训练模型文件 xxx.ckpt 中没有,那么如果不忽略的话,就会导入失败(这样的变量很多,比如卷积层的偏置项 bias,一般预训练模型中没有,所以需要忽略,即使用默认的零初始化)。最后一个参数 reshape_variabels 指定对某些变量进行变形,这个一般用不到,使用默认的 False 即可。

        有了以上的基础,而且你还阅读过上一篇文章 TensorFlow-slim 训练 CNN 分类模型(续) 的话,那么整个使用预训练模型的训练文件 train.py 就很容易写出了,如下(重点在最后几行):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 19:27:44 2018

@author: shirhe-lyh
"""

"""Train a CNN classification model via pretrained ResNet-50 model.

Example Usage:
---------------
python3 train.py \
    --resnet50_model_path: Path to pretrained ResNet-50 model.
    --record_path: Path to training tfrecord file.
    --logdir: Path to log directory.
"""

import tensorflow as tf

import model

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('record_path', None, 'Path to training tfrecord file.')
flags.DEFINE_string('resnet50_model_path', None, 
                    'Path to pretrained ResNet-50 model.')
flags.DEFINE_string('logdir', None, 'Path to log directory.')
FLAGS = flags.FLAGS


def get_record_dataset(record_path,
                       reader=None, image_shape=[224, 224, 3], 
                       num_samples=50000, num_classes=10):
    """Get a tensorflow record file.
    
    Args:
        
    """
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(shape=image_shape, 
                                              #image_key='image/encoded',
                                              #format_key='image/format',
                                              channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and 9.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)


def main(_):
    dataset = get_record_dataset(FLAGS.record_path, num_samples=79573, 
                                 num_classes=54)
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    
    # Data augumentation
    image = tf.image.random_flip_left_right(image)
        
    inputs, labels = tf.train.batch([image, label],
                                    batch_size=64,
                                    allow_smaller_final_batch=True)
    
    cls_model = model.Model(is_training=True)
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    acc = cls_model.accuracy(postprocessed_dict, labels)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', acc)
    

    #optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.99)
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001)
    train_op = slim.learning.create_train_op(loss, optimizer,
                                             summarize_gradients=True)
    
    variables_to_restore = slim.get_variables_to_restore()
    init_fn = slim.assign_from_checkpoint_fn(FLAGS.resnet50_model_path,
                                             variables_to_restore,
                                             ignore_missing_vars=True)
    
    slim.learning.train(train_op=train_op, logdir=FLAGS.logdir, 
                        init_fn=init_fn, 
                        save_summaries_secs=20, 
                        save_interval_secs=600)
    
if __name__ == '__main__':
    tf.app.run()

预告:下一篇文章将要介绍如何用 TensorFlow 来训练多任务多标签模型,敬请期待!

上一篇下一篇

猜你喜欢

热点阅读