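Training a CNN Classifier with TensorFlow
The previous two articles covered how to install TensorFlow and how to use TensorFlow's built-in object detection API. Starting with this article, we look at how to build your own network with TensorFlow, how to save a trained model, how to load a saved model for inference, how to train neural networks with the more convenient tf.contrib.slim, and so on.
This article uses a simple classification task to show how to train a CNN model with TensorFlow.
1. A simple 10-class classification task
Suppose we need to train a 10-class classifier that tells apart the images shown in Figure 1. The images are generated randomly with captcha, a third-party Python library for generating CAPTCHA images (install it with sudo pip/pip3 install captcha). Each image contains one of the 10 digits 0-9 and, as the figure shows, carries heavy background noise. Each image is 28 x 28 pixels and is named image<index>_<label>.jpg. In Figure 1, only the 4th image in row 1, row 2, row 5 and the last row is reasonably easy to recognize by eye.
Training a classifier with high accuracy requires a large amount of training data. The following code generates 50000 training images: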
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018
@author: shirhe-lyh
"""
import cv2
import numpy as np
from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Generate a digit image."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


if __name__ == '__main__':
    output_dir = './datasets/images/'
    for i in range(50000):
        label = np.random.randint(0, 10)
        image = generate_captcha(str(label))
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        cv2.imwrite(output_path, image)
The images are saved in the folder ./datasets/images/. Create this folder by hand before running the code above, or point output_dir at a different folder.
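If you would rather not create the folder by hand, the script could create it itself. The following is a minimal sketch, assuming Python 3.2 or later (for the exist_ok argument); the helper name ensure_dir is hypothetical and does not appear in the original script. It is demonstrated in a temporary directory so that nothing is written to the project.

```python
import os
import tempfile


def ensure_dir(path):
    """Create `path` (and any missing parent folders) if it does not exist."""
    os.makedirs(path, exist_ok=True)
    return path


# Demonstrate inside a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    output_dir = ensure_dir(os.path.join(tmp, 'datasets', 'images'))
    ensure_dir(output_dir)  # Calling it a second time is harmless.
    created = os.path.isdir(output_dir)
```

In the generation script, a call like ensure_dir(output_dir) would go right before the loop.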
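2. Building a simple CNN model
Machine learning and deep learning algorithms generally follow the same pattern (or pipeline): data preprocessing, prediction, postprocessing, and loss computation. To make this reusable, we can define an abstract class and have every later model definition inherit from it. As noted above, the 50000 training images generated with captcha are already hard to tell apart by eye, so a fairly deep network is needed. The network built below has 6 convolutional layers and 3 fully connected layers, and it reaches an accuracy above 99%.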
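The abstract-class idea can be sketched in isolation. This is a minimal illustration, assuming Python 3, where the abstract-method check only takes effect if the class inherits from abc.ABC or is declared with metaclass=ABCMeta (a `__metaclass__` class attribute is ignored in Python 3). The class names Base, Incomplete and Complete are hypothetical.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Hypothetical base class with one abstract method."""

    @abstractmethod
    def predict(self, inputs):
        """To be implemented by subclasses."""


class Incomplete(Base):
    """Does not implement predict, so it cannot be instantiated."""


class Complete(Base):
    def predict(self, inputs):
        return {'logits': inputs}


# Instantiating a subclass that misses an abstract method raises TypeError.
try:
    Incomplete()
    incomplete_error = None
except TypeError as error:
    incomplete_error = error

result = Complete().predict([1, 2])
```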
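At the low level, TensorFlow builds neural networks with the tf.nn module: convolution is implemented by the function tf.nn.conv2d and pooling by tf.nn.max_pool. There is no ready-made function in tf.nn for a fully connected layer, so it has to be written by hand with matrix multiplication (tf.matmul) and addition (tf.add). A later article will introduce tf.contrib.slim, a more convenient module for building neural networks.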
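To make the "matrix multiplication plus addition" description concrete, here is a small sketch of a fully connected layer. It uses NumPy rather than TensorFlow so that it runs without a TensorFlow 1.x installation; the function names fully_connected and relu are hypothetical, and the shapes are arbitrary illustration values.

```python
import numpy as np


def relu(x):
    """Rectified linear unit: element-wise max(x, 0)."""
    return np.maximum(x, 0.0)


def fully_connected(x, weights, biases, activation=None):
    """Compute activation(x @ weights + biases) for a batch of row vectors."""
    out = np.matmul(x, weights) + biases
    if activation is not None:
        out = activation(out)
    return out


rng = np.random.RandomState(0)
x = rng.randn(4, 6)   # A batch of 4 feature vectors of length 6.
w = rng.randn(6, 3)   # Maps 6 input features to 3 output units.
b = np.zeros(3)
y = fully_connected(x, w, b, activation=relu)
```

The TensorFlow version in model.py has the same structure, with tf.matmul, tf.add and tf.nn.relu in place of the NumPy calls.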
Without further ado, here is the code (saved as model.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from abc import ABCMeta
from abc import abstractmethod


# Python 3 ignores a `__metaclass__` attribute, so the metaclass is passed
# as a keyword argument to make the abstract methods actually enforced.
class BaseModel(metaclass=ABCMeta):
    """Abstract base class for any model."""

    def __init__(self, num_classes):
        """Constructor.

        Args:
            num_classes: Number of classes.
        """
        self._num_classes = num_classes

    @property
    def num_classes(self):
        return self._num_classes

    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be overridden by implementations.

        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
        """
        pass

    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass

    @abstractmethod
    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.

        Returns:
            A dictionary containing the postprocessed results.
        """
        pass

    @abstractmethod
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each image in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        pass


class Model(BaseModel):
    """A simple 10-classification CNN model definition."""

    def __init__(self,
                 is_training,
                 num_classes):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        super(Model, self).__init__(num_classes=num_classes)

        self._is_training = is_training

    def preprocess(self, inputs):
        """Scale pixel values from [0, 255] to roughly [-1, 1).

        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
        """
        preprocessed_inputs = tf.to_float(inputs)
        preprocessed_inputs = tf.subtract(preprocessed_inputs, 128.0)
        preprocessed_inputs = tf.div(preprocessed_inputs, 128.0)
        return preprocessed_inputs

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        shape = preprocessed_inputs.get_shape().as_list()
        height, width, num_channels = shape[1:]

        # Weights and biases of the 6 convolutional layers (3x3 kernels).
        conv1_weights = tf.get_variable(
            'conv1_weights', shape=[3, 3, num_channels, 32],
            dtype=tf.float32)
        conv1_biases = tf.get_variable(
            'conv1_biases', shape=[32], dtype=tf.float32)
        conv2_weights = tf.get_variable(
            'conv2_weights', shape=[3, 3, 32, 32], dtype=tf.float32)
        conv2_biases = tf.get_variable(
            'conv2_biases', shape=[32], dtype=tf.float32)
        conv3_weights = tf.get_variable(
            'conv3_weights', shape=[3, 3, 32, 64], dtype=tf.float32)
        conv3_biases = tf.get_variable(
            'conv3_biases', shape=[64], dtype=tf.float32)
        conv4_weights = tf.get_variable(
            'conv4_weights', shape=[3, 3, 64, 64], dtype=tf.float32)
        conv4_biases = tf.get_variable(
            'conv4_biases', shape=[64], dtype=tf.float32)
        conv5_weights = tf.get_variable(
            'conv5_weights', shape=[3, 3, 64, 128], dtype=tf.float32)
        conv5_biases = tf.get_variable(
            'conv5_biases', shape=[128], dtype=tf.float32)
        conv6_weights = tf.get_variable(
            'conv6_weights', shape=[3, 3, 128, 128], dtype=tf.float32)
        conv6_biases = tf.get_variable(
            'conv6_biases', shape=[128], dtype=tf.float32)

        # Two 2x2 max-pools with stride 2 shrink each spatial side by 4
        # (28 -> 14 -> 7 for the images used here).
        flat_height = height // 4
        flat_width = width // 4
        flat_size = flat_height * flat_width * 128

        # Weights and biases of the 3 fully connected layers.
        fc7_weights = tf.get_variable(
            'fc7_weights', shape=[flat_size, 512], dtype=tf.float32)
        fc7_biases = tf.get_variable(
            'f7_biases', shape=[512], dtype=tf.float32)
        fc8_weights = tf.get_variable(
            'fc8_weights', shape=[512, 512], dtype=tf.float32)
        fc8_biases = tf.get_variable(
            'f8_biases', shape=[512], dtype=tf.float32)
        fc9_weights = tf.get_variable(
            'fc9_weights', shape=[512, self.num_classes], dtype=tf.float32)
        fc9_biases = tf.get_variable(
            'f9_biases', shape=[self.num_classes], dtype=tf.float32)

        # Block 1: conv1 -> conv2 -> max-pool.
        net = preprocessed_inputs
        net = tf.nn.conv2d(net, conv1_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv1_biases))
        net = tf.nn.conv2d(net, conv2_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv2_biases))
        net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                             padding='SAME')
        # Block 2: conv3 -> conv4 -> max-pool.
        net = tf.nn.conv2d(net, conv3_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv3_biases))
        net = tf.nn.conv2d(net, conv4_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv4_biases))
        net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                             padding='SAME')
        # Block 3: conv5 -> conv6 (no pooling).
        net = tf.nn.conv2d(net, conv5_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv5_biases))
        net = tf.nn.conv2d(net, conv6_weights, strides=[1, 1, 1, 1],
                           padding='SAME')
        net = tf.nn.relu(tf.nn.bias_add(net, conv6_biases))
        # Flatten, then three fully connected layers; the last one outputs
        # raw logits without an activation.
        net = tf.reshape(net, shape=[-1, flat_size])
        net = tf.nn.relu(tf.add(tf.matmul(net, fc7_weights), fc7_biases))
        net = tf.nn.relu(tf.add(tf.matmul(net, fc8_weights), fc8_biases))
        net = tf.add(tf.matmul(net, fc9_weights), fc9_biases)
        prediction_dict = {'logits': net}
        return prediction_dict

    def postprocess(self, prediction_dict, **params):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            **params: Additional keyword arguments for specific implementations
                of specified models.

        Returns:
            A dictionary containing the postprocessed results.
        """
        logits = prediction_dict['logits']
        logits = tf.nn.softmax(logits)
        classes = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
        postprocessed_dict = {'classes': classes}
        return postprocessed_dict

    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each image in the batch.

        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict['logits']
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=groundtruth_lists))
        loss_dict = {'loss': loss}
        return loss_dict
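The code above first defines a base class, BaseModel, which declares several commonly used abstract functions: the data preprocessing function preprocess, the model definition and prediction function predict, the postprocessing function postprocess, and the loss function loss. It then defines the actual CNN model class, Model(BaseModel), which inherits from BaseModel and whose member functions implement all of the base class's abstract functions. The most important part, the CNN model itself, is defined in predict. First, tf.get_variable declares the weight and bias variables. This function creates a variable if it has not been defined yet; if the variable already exists and the enclosing variable scope has reuse enabled, it returns the existing variable instead, which suits variables that are updated iteratively during training. Note: in general there is no need to change the default initializer of tf.get_variable (Glorot uniform when none is given); changing it can easily produce a model that cannot be optimized, that is, one whose loss does not decrease (readers can pass initializer=tf.truncated_normal_initializer(stddev=x) and try a few values to see this). Next come the 6 convolutional layers, all using the rectified linear unit tf.nn.relu as the activation function, together with two max-pooling steps, tf.nn.max_pool. Finally there are three fully connected layers of the form tf.add(tf.matmul(·, ·), ·). The final result is returned in a dictionary, and the last fully connected layer has no activation function applied.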
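The size of the flattened tensor fed to the first fully connected layer follows from the two pooling steps. The sketch below works through the arithmetic in plain Python, assuming the standard rule that a stride-2 pooling layer with padding='SAME' outputs ceil(size / 2) along each spatial side; the helper name same_pool_out is hypothetical.

```python
import math


def same_pool_out(size, stride=2):
    """Spatial output size of a pooling layer with padding='SAME'."""
    return int(math.ceil(size / stride))


height = width = 28
for _ in range(2):              # Two max-pooling layers in the network.
    height = same_pool_out(height)
    width = same_pool_out(width)

# 128 is the number of output channels of the last convolutional layer.
flat_size = height * width * 128

# model.py uses `height // 4`, which matches only when the side length is
# divisible by 4. For a side of 30 the two computations disagree.
side_30_pooled = same_pool_out(same_pool_out(30))
side_30_floor = 30 // 4
```

For the 28 x 28 images used here both computations give 7, so flat_size is 7 * 7 * 128. For other input sizes, the ceil-based computation is the safer one.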
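The other three functions are self-explanatory. Preprocessing normalizes the pixel values to roughly [-1, 1), postprocessing applies tf.nn.softmax and then picks the class label with the highest probability, and the loss is the cross-entropy loss.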
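These three steps can be checked numerically outside TensorFlow. The following is a NumPy sketch under the assumption that it mirrors the arithmetic in model.py; it is not the TensorFlow implementation itself, and the function names and sample values are for illustration only.

```python
import numpy as np


def preprocess(images):
    """Map uint8 pixel values to floats: (x - 128) / 128."""
    return (images.astype(np.float32) - 128.0) / 128.0


def softmax(logits):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def sparse_cross_entropy(logits, labels):
    """Mean of -log(probability assigned to the true class)."""
    probs = softmax(logits)
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(picked)))


# Pixel 0 maps to -1, 128 to 0 and 255 to 127/128, so the upper end of the
# range is slightly below 1.
scaled = preprocess(np.array([0, 128, 255], dtype=np.uint8))

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
classes = softmax(logits).argmax(axis=1)
loss = sparse_cross_entropy(logits, labels)
```

Because softmax preserves the ordering of the logits, taking the argmax of the logits directly yields the same classes; the softmax in postprocess matters only if the probabilities themselves are needed.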
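Most of the results above are returned as dictionaries for the sake of generality; feel free to change this if it seems unnecessary. Looking at model.py as a whole, most of the code sits in predict, the function that defines the CNN model: its first half declares many variables by hand, and its second half stacks many convolutional and fully connected layers by hand. In later articles we will use tensorflow.contrib.slim, a more convenient and more compact tool that makes model definitions more intuitive and more concise, and can take even less code than Keras.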
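3. Training
With the model defined, all that remains is to read in the data and train. Since we generated only 50000 small images of 28 x 28 pixels, they take up little memory, so the code below loads all images at once and then randomly samples one batch from them at each training step.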
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 19:27:44 2018
@author: shirhe-lyh
"""
"""Train a CNN model to classify 10 digits.

Example Usage:
---------------
python3 train.py \
    --images_path: Path to the training images (directory).
    --model_output_path: Path to model.ckpt.
"""
import cv2
import glob
import numpy as np
import os
import tensorflow as tf

import model

flags = tf.app.flags
flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
FLAGS = flags.FLAGS


def get_train_data(images_path):
    """Get the training images from images_path.

    Args:
        images_path: Path to training images.

    Returns:
        images: A list of images.
        labels: A list of integers representing the classes of images.

    Raises:
        ValueError: If images_path does not exist.
    """
    if not os.path.exists(images_path):
        raise ValueError('images_path does not exist.')

    images = []
    labels = []
    images_path = os.path.join(images_path, '*.jpg')
    count = 0
    for image_file in glob.glob(images_path):
        count += 1
        if count % 100 == 0:
            print('Load {} images.'.format(count))
        image = cv2.imread(image_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Assume the name of each image is imagexxx_label.jpg
        label = int(image_file.split('_')[-1].split('.')[0])
        images.append(image)
        labels.append(label)
    images = np.array(images)
    labels = np.array(labels)
    return images, labels


def next_batch_set(images, labels, batch_size=128):
    """Generate a batch of training data.

    Args:
        images: A 4-D array representing the training images.
        labels: A 1-D array representing the classes of images.
        batch_size: An integer.

    Returns:
        batch_images: A batch of images.
        batch_labels: A batch of labels.
    """
    # np.random.choice samples with replacement by default.
    indices = np.random.choice(len(images), batch_size)
    batch_images = images[indices]
    batch_labels = labels[indices]
    return batch_images, batch_labels


def main(_):
    inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
    labels = tf.placeholder(tf.int32, shape=[None], name='labels')

    cls_model = model.Model(is_training=True, num_classes=10)
    preprocessed_inputs = cls_model.preprocess(inputs)
    prediction_dict = cls_model.predict(preprocessed_inputs)
    loss_dict = cls_model.loss(prediction_dict, labels)
    loss = loss_dict['loss']
    postprocessed_dict = cls_model.postprocess(prediction_dict)
    classes = postprocessed_dict['classes']
    classes_ = tf.identity(classes, name='classes')
    acc = tf.reduce_mean(tf.cast(tf.equal(classes, labels), 'float'))

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(0.1, global_step, 150, 0.9)

    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
    train_step = optimizer.minimize(loss, global_step)

    saver = tf.train.Saver()

    images, targets = get_train_data(FLAGS.images_path)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        for i in range(6000):
            batch_images, batch_labels = next_batch_set(images, targets)
            train_dict = {inputs: batch_images, labels: batch_labels}

            sess.run(train_step, feed_dict=train_dict)

            loss_, acc_ = sess.run([loss, acc], feed_dict=train_dict)

            train_text = 'step: {}, loss: {}, acc: {}'.format(
                i+1, loss_, acc_)
            print(train_text)

        saver.save(sess, FLAGS.model_output_path)


if __name__ == '__main__':
    tf.app.run()
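In the code above (saved as train.py), the function get_train_data reads all of the training data from the image folder in one go. Because OpenCV reads images in BGR channel order, an extra step converts them to RGB. The first time train.py is run from a terminal, this function takes a fairly long time; on the second and later runs, caching lets execution reach the training part quickly. The next function, next_batch_set, randomly picks one batch, which is easy to do with NumPy's np.random.choice.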
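One detail worth knowing is that np.random.choice samples with replacement by default, so a batch may contain the same image more than once and some images may never be drawn. The sketch below shows this style of sampling next to an epoch-style alternative that visits every image exactly once per pass. The tiny arrays are made-up stand-ins for the real images and labels, and the alternative is a common variation, not the approach used in train.py.

```python
import numpy as np

rng = np.random.RandomState(0)

# Stand-in data: 10 "images" of shape 2 x 2 x 3 whose label is their index.
images = np.arange(10 * 2 * 2 * 3).reshape(10, 2, 2, 3)
labels = np.arange(10)

# Same style as next_batch_set: random indices, drawn with replacement.
indices = rng.choice(len(images), 4)
batch_images, batch_labels = images[indices], labels[indices]

# Alternative: shuffle once per epoch and slice, so no image is repeated
# within an epoch. The last batch may be smaller than the others.
perm = rng.permutation(len(images))
epoch_batches = [perm[i:i + 4] for i in range(0, len(perm), 4)]
```

Either way, images and labels are indexed with the same index array, which keeps each image paired with its label.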
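Now for the training part. First, tf.placeholder defines two placeholders that serve as the entry points for the training data. Then an object cls_model of class Model is instantiated (the is_training variable has no effect and can be removed from model.py), and data preprocessing, prediction, loss computation, accuracy computation and so on are carried out step by step. The seemingly redundant line
classes_ = tf.identity(classes, name='classes')
names the output tensor so that the model is easy to call later. Next, tf.train.MomentumOptimizer is chosen as the optimization algorithm; the initial learning rate and the decay steps were determined by experiment. The learning rate usually needs several tries before the model converges to a solution with high accuracy. You can test values such as 0.001, 0.01, 0.1, 0.0001 and 0.00001, check whether the loss decreases (or the accuracy rises) after 100-200 iterations, and then fine-tune the decay steps and the momentum parameter. Finally, a loop trains the model iteratively, and the trained parameters are saved once training ends.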
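To get a feel for the schedule produced by tf.train.exponential_decay(0.1, global_step, 150, 0.9), the formula can be evaluated in plain Python. This sketch assumes the default non-staircase behavior, in which the learning rate is initial_lr * decay_rate ** (step / decay_steps); the function below is a re-implementation for illustration, not the TensorFlow function.

```python
def exponential_decay(initial_lr, step, decay_steps, decay_rate):
    """Continuously decayed learning rate (non-staircase variant)."""
    return initial_lr * decay_rate ** (step / decay_steps)


# Learning rate at a few points of the 6000-step training run.
schedule = {
    step: exponential_decay(0.1, step, 150, 0.9)
    for step in (0, 150, 3000, 6000)
}
```

With these settings the learning rate is multiplied by 0.9 every 150 steps, so over 6000 steps it shrinks by a factor of 0.9 ** 40.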
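To run it, open a terminal in the directory containing train.py and execute:
python3 train.py --images_path /home/.../datasets/images \
    --model_output_path /home/.../model.ckpt
This specifies the training data folder and the model output path (the file name prefix for the saved model). When training finishes, files such as model.ckpt.data-00000-of-00001 and model.ckpt.index are created under the model output path.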
4. Testing the model
At this point you probably want to see how the trained model performs on unseen data. To do so, run the following code (saved as evaluate.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 2 14:02:05 2018
@author: shirhe-lyh
"""
import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('model_ckpt_path', None, 'Path to model checkpoint.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    """Generate a digit image."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    with tf.Session() as sess:
        ckpt_path = FLAGS.model_ckpt_path
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)

        inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
        classes = tf.get_default_graph().get_tensor_by_name('classes:0')

        for i in range(10):
            label = np.random.randint(0, 10)
            image = generate_captcha(str(label))
            image_np = np.expand_dims(image, axis=0)
            predicted_label = sess.run(classes,
                                       feed_dict={inputs: image_np})
            print(predicted_label, ' vs ', label)


if __name__ == '__main__':
    tf.app.run()
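Run it from a terminal:
python3 evaluate.py --model_ckpt_path /home/.../model.ckpt
and you will see the results of your training. The code above first loads the saved model with tf.train.import_meta_graph and .restore, then looks up the model's input and output tensors by name, after which it can run inference on input images.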
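evaluate.py prints each prediction next to its true label. If you would like a single number instead, the pairs can be collected and summarized as an accuracy. A minimal sketch in plain Python; the function name accuracy and the sample lists are hypothetical and do not come from an actual run.

```python
def accuracy(predicted, expected):
    """Fraction of positions where the predicted label equals the true one."""
    if len(predicted) != len(expected):
        raise ValueError('predicted and expected must have the same length.')
    if not predicted:
        return 0.0
    correct = sum(int(p == t) for p, t in zip(predicted, expected))
    return correct / len(predicted)


# Made-up labels standing in for values collected inside the evaluation loop.
predicted_labels = [3, 1, 4, 1, 5]
true_labels = [3, 1, 4, 0, 5]
acc = accuracy(predicted_labels, true_labels)
```

In evaluate.py, the two lists would be filled inside the loop, for example by appending int(predicted_label[0]) and label on each iteration.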
The next article will cover saving and loading models. Stay tuned.