深度学习我爱编程

TensorFlow 自定义生成 .record 文件

2018-04-07  本文已影响268人  公输睚信

一、生成 .record 文件

        前面的文章 TensorFlow 训练自己的目标检测器 中的第二部分第 2 小节中我们已经预先说过,会在后续的文章中阐述怎么自定义的将图像转化为 .record 文件,今天我们就来说一说这件事。

        在文章 TensorFlow 训练 CNN 分类器 中我们生成了 50000 张 28 x 28 像素的图像,我们的目标就是将这些图像全部写入到一个后缀为 .record 的文件中。 .recordtfrecord 文件是 TensorFlow 中的标准数据读写格式,它是一种能够高效读写的二进制文件,能够快速的复制、移动、读写和存储等。

        在文章 TensorFlow 训练 CNN 分类器 和文章 TensorFlow-slim 训练 CNN 分类模型 中,我们在训练模型时导入数据的方式都是一次性的将所有图像读入,然后循环的从中选择一个批量来训练。这对于小数据集来说不会产生问题,但如果训练数据异常大,那么很可能由于内存限制无法一次性将说有数据导入,这样前面的训练方式便不能采用了。此时,我们可以将数据转化为 .record 文件格式,然后再分批次的、逐步的读入 .record 文件进行训练。

        要将图像写入 .record 文件,首先要将图像编码为字符或数字特征,这需要调用类 tf.train.Feature。然后,在调用 tf.train.Example 将特征写入协议缓冲区。最后,通过类 tf.python_io.TFRecordWriter 将数据写入到 .record 文件中。比如,我们将前面提到的 50000 张图像写入 train.record 文件,使用如下代码(命名为 generate_tfrecord.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 26 09:02:10 2018

@author: shirhe-lyh
"""

"""Generate tfrecord file from images.

Example Usage:
---------------
python3 train.py \
    --images_path: Path to the training images (directory).
    --output_path: Path to .record.
"""

import glob
import io
import os
import tensorflow as tf

from PIL import Image

flags = tf.app.flags

flags.DEFINE_string('images_path', None, 'Path to images (directory).')
flags.DEFINE_string('output_path', None, 'Path to output tfrecord file.')
FLAGS = flags.FLAGS


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_tf_example(image_path):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size
    label = int(image_path.split('_')[-1].split('.')[0])
    
    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))
    return tf_example


def generate_tfrecord(images_path, output_path):
    writer = tf.python_io.TFRecordWriter(output_path)
    for image_file in glob.glob(images_path):
        tf_example = create_tf_example(image_file)
        writer.write(tf_example.SerializeToString())
    writer.close()
    
    
def main(_):
    images_path = os.path.join(FLAGS.images_path, '*.jpg')
    images_record_path = FLAGS.output_path
    generate_tfrecord(images_path, images_record_path)
    
    
if __name__ == '__main__':
    tf.app.run()

在该文件目录的终端执行:

python3 generate_tfrecord.py \
    --images_path /home/.../datasets/images \
    --output_path /home/.../datasets/train.record

便会在输出路径下生成 train.record 文件。以上代码中,最重要的部分是:1. 函数 create_tf_example,该函数首先得到图像的二进制格式、图像的宽和高、以及图像对应的类标号等,然后将图像的这些信息写入协议缓冲区;2. 函数 generate_tfrecord,该函数使用 tf.python_io.TFRecordWriter 类将协议缓冲区内的数据写入到 .record 文件中。

二、读取 .record 文件

        一旦将图像转化为了 .record 文件,接下来我们关心的就是怎么读取这个 .record 文件用于模型训练了。这可以借助我们前面使用过的模块 tf.contrib.slim

slim = tf.contrib.slim

def get_record_dataset(record_path,
                       reader=None, image_shape=[28, 28, 3], 
                       num_samples=50000, num_classes=10):
    """Get a tensorflow record file."""
    if not reader:
        reader = tf.TFRecordReader
        
    keys_to_features = {
        'image/encoded': 
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': 
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/class/label': 
            tf.FixedLenFeature([1], tf.int64, default_value=tf.zeros([1], 
                               dtype=tf.int64))}
        
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(shape=image_shape, 
                                              #image_key='image/encoded',
                                              #format_key='image/format',
                                              channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[])}
    
    decoder = slim.tfexample_decoder.TFExampleDecoder(
        keys_to_features, items_to_handlers)
    
    labels_to_names = None
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and 9.'}
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=labels_to_names)

主要是借助了 tf.contib.slim 模块中的

slim.dataset.Dataset(data_sources, reader, decoder,
                  num_samples, items_to_descriptions,
                  **kwargs)

slim.tfexample_decoder.TFExampleDecoder(keys_to_features,
                                        items_to_handlers)

这两个类。

        使用时,直接传入 train.record 路径即可:

dataset = get_record_dataset('./xxx/train.record')
data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = data_provider.get(['image', 'label'])

函数 get_record_dataset 返回 slim.dataset.Dataset 类的一个对象,之后通过类

slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=1,
                                               reader_kwargs=None,
                                               shuffle=True, num_epochs=None,
                                               common_queue_capacity=256,
                                               common_queue_min=128,
                                               record_key='record_key',
                                               seed=None, scope=None)

get 方法得到图像和类标号的序列数据。参数 num_readers=1 表示一次读取一个数据,即一次读取一张图像,因此实际使用是还需要使用函数 tf.train.batch 将数据形成批量再用于训练,见下一篇文章。

预告:下一篇文章将说明怎么完全使用 tf.contrib.slim 来构建和训练模型。

上一篇下一篇

猜你喜欢

热点阅读