TensorFlow 自定义生成TFRecord文件

2019-05-06  本文已影响0人  技术大渣渣

一.TFRecord简介

TensorFlow提供了TFRecord格式来统一存储数据。它是一种能够将图像数据和标签存放在一起的二进制文件,这种格式可以使TensorFlow的数据集更容易与网络应用架构相匹配,并且可以在TensorFlow中快速地复制、移动、读取和存储。

二.TFRecord生成

以kaggle比赛猫狗数据集(下载链接:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA 密码:dmp4)为例生成TFRecord文件。
1.首先下载数据集,通过如下data_provider.py将数据分成train和validation两部分,并将划分后的图片路径和label信息分别保存到train_imagePath_label.json和val_imagePath_label.json文件中。

import glob
import json
import os
import random
"""
Get image path and label
"""


def split_train_val_datasets(images_path, tarin_ratio=0.90):
    """Shuffle the ``*.jpg`` files of a directory and split them in two.

    Args:
        images_path: Directory searched (non-recursively) for ``*.jpg`` files.
        tarin_ratio: Fraction of the shuffled files assigned to the training
            split; the remaining files form the validation split. The
            misspelled name is kept so existing keyword callers keep working.

    Returns:
        A ``(train_dict, val_dict)`` tuple holding the result of
        ``get_label_dict`` for each split.

    Raises:
        ValueError: If ``images_path`` does not exist.
    """
    if not os.path.exists(images_path):
        raise ValueError('images_path does not exist')

    image_files = glob.glob(os.path.join(images_path, '*.jpg'))
    # The shuffle is unseeded, so every call produces a different split.
    random.shuffle(image_files)
    num_train_samples = int(len(image_files) * tarin_ratio)

    train_files = image_files[:num_train_samples]
    val_files = image_files[num_train_samples:]

    return get_label_dict(train_files), get_label_dict(val_files)


def get_label_dict(image_files=None):
    """Map each cat/dog image path to its integer class label.

    The label is derived from the file name: names starting with ``cat`` get
    0 and names starting with ``dog`` get 1. Files matching neither prefix are
    left out of the result.

    Args:
        image_files: Iterable of image paths, or None.

    Returns:
        A dict mapping each recognised path (unchanged) to 0 or 1, or None
        when ``image_files`` is None.
    """
    if image_files is None:
        return None
    label_dict = {}
    for image_file in image_files:
        # Accept both '/' and '\' separators so the file name is extracted
        # correctly on every platform, not only on Windows.
        image_name = image_file.replace('\\', '/').rsplit('/', 1)[-1]
        if image_name.startswith('cat'):
            label_dict[image_file] = 0
        elif image_name.startswith('dog'):
            label_dict[image_file] = 1

    return label_dict


def write_annotation_json(images_path, train_json_output_path,
                          val_jason_output_path):
    """Split the images and save both annotation mappings as JSON files.

    Args:
        images_path: Directory of images, forwarded to
            ``split_train_val_datasets``.
        train_json_output_path: File that receives the training annotations.
        val_jason_output_path: File that receives the validation annotations.
            The misspelled name is kept for backward compatibility.
    """
    train_files_dict, val_files_dict = split_train_val_datasets(images_path)

    outputs = ((train_json_output_path, train_files_dict),
               (val_jason_output_path, val_files_dict))
    for output_path, files_dict in outputs:
        # Create the target directory first; open() does not create parents.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # The mapping is deliberately JSON-encoded twice (a JSON string that
        # itself contains JSON) because provide() decodes it twice.
        with open(output_path, 'w') as writer:
            json.dump(json.dumps(files_dict), writer)


def provide(annotion_path=None):
    """Load an annotation JSON file and return its path-to-label mapping.

    Args:
        annotion_path: Path to an annotation JSON file. Both the
            double-encoded layout (a JSON string containing a JSON object)
            and a plain JSON object are accepted.

    Returns:
        A new dict mapping each image path to its class label.

    Raises:
        ValueError: If ``annotion_path`` is None or does not exist.
    """
    # Check None explicitly: os.path.exists(None) raises TypeError.
    if annotion_path is None or not os.path.exists(annotion_path):
        raise ValueError('annotion path does not exist.')

    with open(annotion_path, 'r') as reader:
        annotion_d = json.load(reader)
    # A str at this point means the file was double-encoded; decode again.
    if isinstance(annotion_d, str):
        annotion_d = json.loads(annotion_d)

    return dict(annotion_d)


if __name__ == '__main__':
    # Script entry point: split the images under ./data, write both
    # annotation files, then read the training annotations back.
    images_path = './data'
    train_annotation_path, val_annotation_path = (
        './datasets/train/train_imagePath_label.json',
        './datasets/val/val_imagePath_label.json',
    )

    write_annotation_json(images_path,
                          train_annotation_path,
                          val_annotation_path)
    provide(train_annotation_path)

2.通过generate_tfrecord.py生成.record文件:

"""
Generate tfrecord from images and labels.

Example Usage:

python3 generate_tfrecord.py
       --image_path: Path to images directory.
       --train_annotation_path: Path to train annotation .json file.
       --train_output_path: Path to save train tfrecord
       --val_annotation_path: Path to val annotation .json file.
       --val_output_path: Path to save val tfrecord.
       --resize_size: Resize image to fixed size.

"""
import tensorflow as tf
import data_provider
from PIL import Image
import io

# Command-line flags of this script; the parsed values are read through
# FLAGS inside main().
flags = tf.app.flags

# Directory containing the source images.
flags.DEFINE_string('image_path',
                    './data-org',
                    'Path to images directory.')
# Training split: annotation JSON file and output TFRecord file.
flags.DEFINE_string('train_annotation_path',
                    './datasets/train/train_imagePath_label.json',
                    'Path to train annotation json file.')
flags.DEFINE_string('train_output_path',
                    './datasets/train/train.record',
                    'Path to save train tfrecord.')

# Validation split: annotation JSON file and output TFRecord file.
flags.DEFINE_string('val_annotation_path',
                    './datasets/val/val_imagePath_label.json',
                    'Path to val annotation .json file.')
flags.DEFINE_string('val_output_path',
                    './datasets/val/val.record',
                    'Path to save val tfrecord.')
# Side length passed to create_tf_example as resize_size.
flags.DEFINE_integer('resize_size', 224, 'Resize image to fixed size.')

FLAGS = flags.FLAGS


def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def int64_list_feature(value):
    """Wrap a sequence of integers in a ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def bytes_list_feature(value):
    """Wrap a sequence of bytes values in a ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)


def float_feature(value):
    """Wrap a single float in a ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def float_list_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)


def create_tf_example(image_path, label, resize_size=None):
    """Build a ``tf.train.Example`` for one image file and its label.

    Args:
        image_path: Path of the image file to read.
        label: Integer class label, stored under ``image/class/label``.
        resize_size: If not None, the image is resized to
            ``resize_size x resize_size`` and re-encoded as JPEG; otherwise
            the original file bytes are stored unchanged.

    Returns:
        A ``tf.train.Example`` with the keys ``image/encoded``,
        ``image/format``, ``image/path``, ``image/class/label``,
        ``image/height`` and ``image/width``.
    """
    # Binary mode: the file content is stored as raw bytes.
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    # Decode from memory only to learn the size and, optionally, to resize.
    image = Image.open(io.BytesIO(encoded_jpg))

    width, height = image.size
    if resize_size is not None:
        width = height = resize_size

        # JPEG cannot store alpha or palette modes, so normalise to RGB
        # before resizing and re-encoding.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
        # same filter under its current name.
        image = image.resize((width, height), Image.LANCZOS)
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')  # Encode into an in-memory buffer.
        encoded_jpg = bytes_io.getvalue()

    # tf.train.Example is the protocol-buffer record written to the file;
    # its features map each key to a tf.train.Feature.
    tf_example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/encoded': bytes_feature(encoded_jpg),
            'image/format': bytes_feature('jpg'.encode()),
            'image/path': bytes_feature(image_path.encode('utf8')),
            'image/class/label': int64_feature(label),
            'image/height': int64_feature(height),
            'image/width': int64_feature(width)}))
    return tf_example


def generate_tfrecord(annotation_dict, output_path, resize_size=None):
    """Write one ``tf.train.Example`` per annotated image to a TFRecord file.

    Args:
        annotation_dict: Mapping from image path to class label.
        output_path: Path of the TFRecord file to create.
        resize_size: Forwarded to ``create_tf_example``.
    """
    num_valid_tf_examples = 0
    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(output_path) as writer:
        for image_path, label in annotation_dict.items():
            # A GFile object is always truthy, so existence has to be
            # checked explicitly.
            if not tf.gfile.Exists(image_path):
                print('{} does not exist'.format(image_path))
                continue
            tf_example = create_tf_example(image_path, label, resize_size)
            # Serialize the Example to bytes before writing it.
            writer.write(tf_example.SerializeToString())
            num_valid_tf_examples += 1

            if num_valid_tf_examples % 100 == 0:
                print('{} tf examples created.'.format(num_valid_tf_examples))
    print('Total {} tf examples created.'.format(num_valid_tf_examples))


def main(_):
    """Write the annotation JSON files, then convert each split to TFRecord.

    Args:
        _: Unused argument supplied by ``tf.app.run``.
    """
    annotation_paths = (FLAGS.train_annotation_path,
                        FLAGS.val_annotation_path)
    record_paths = (FLAGS.train_output_path, FLAGS.val_output_path)

    # Create the train/val annotation files before reading them back.
    data_provider.write_annotation_json(FLAGS.image_path, *annotation_paths)

    # Load both mappings first, then write the records (train before val).
    annotation_dicts = [data_provider.provide(path)
                        for path in annotation_paths]
    for annotation_dict, record_path in zip(annotation_dicts, record_paths):
        generate_tfrecord(annotation_dict, record_path, FLAGS.resize_size)


if __name__ == '__main__':
    # Entry point when executed as a script: tf.app.run() parses the
    # command-line flags and then calls main().
    tf.app.run()

三.检查生成的TFRecord是否正确

运行view_tfrecord.py,实现对.record文件中图像和label的解析:

"""
Created on 04/23/2019

@author:sunjiankang

View and check created tfrecord.

Example Usage:

python3 view_tfrecord.py
       --tfrecord_path: Path to tfrecord.
       --num_samples: Num of samples.
       --num_classes: Num of classes.

"""
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
# Short alias for the TF-Slim library (TF 1.x contrib module).
slim = tf.contrib.slim

# Command-line flags of this script; the parsed values are read through
# FLAGS inside main().
flags = tf.app.flags

flags.DEFINE_string('tfrecord_path',
                    './datasets/val/val.record',
                    'Path to tfrecord.')
# num_samples is also used by main() as the number of records to display.
flags.DEFINE_integer('num_samples', 2000, 'Num of samples.')
flags.DEFINE_integer('num_classes', 2, 'Num of classes.')

FLAGS = flags.FLAGS


def get_record_dataset(record_path,
                       reader=None,
                       num_samples=2000,
                       num_classes=2):
    """Describe a TFRecord file as a ``slim.dataset.Dataset``.

    Args:
        record_path: Passed to the dataset as ``data_sources``.
        reader: Reader class; ``tf.TFRecordReader`` is used when falsy.
        num_samples: Stored on the dataset as ``num_samples``.
        num_classes: Stored on the dataset as ``num_classes``.

    Returns:
        A ``slim.dataset.Dataset`` whose decoder exposes the items
        ``image``, ``label``, ``path``, ``height`` and ``width``.
    """
    decoder_lib = slim.tfexample_decoder
    reader = reader or tf.TFRecordReader

    # Step 1 (done by TensorFlow): how each serialized Example is parsed
    # back into the features that were stored.
    keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/path':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label':
            tf.FixedLenFeature(
                [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
        'image/height':
            tf.FixedLenFeature([], tf.int64),
        'image/width':
            tf.FixedLenFeature([], tf.int64),
    }

    # Step 2 (done by slim): how the parsed features are assembled into
    # higher-level items. The image is decoded; the rest are scalar tensors.
    items_to_handlers = {
        'image': decoder_lib.Image(image_key='image/encoded',
                                   format_key='image/format',
                                   channels=3),
    }
    scalar_items = (('label', 'image/class/label'),
                    ('path', 'image/path'),
                    ('height', 'image/height'),
                    ('width', 'image/width'))
    for item_name, feature_key in scalar_items:
        items_to_handlers[item_name] = decoder_lib.Tensor(feature_key,
                                                          shape=[])

    # The decoder combines both steps.
    decoder = decoder_lib.TFExampleDecoder(keys_to_features,
                                           items_to_handlers)

    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and 1.',
        'path': 'image path',
        'height': 'Image height.',
        'width': 'Image width.',
    }

    # The Dataset object bundles the file location, the decoding scheme and
    # other metadata of the data set.
    return slim.dataset.Dataset(
        data_sources=record_path,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        num_classes=num_classes,
        items_to_descriptions=items_to_descriptions,
        labels_to_names=None)


def main(_):
    """Decode records from the TFRecord file and display them with OpenCV.

    Up to ``FLAGS.num_samples`` records are read. For each one the label and
    image path are printed and the image is shown in a window; pressing Esc
    stops the loop early.

    Args:
        _: Unused argument supplied by ``tf.app.run``.
    """
    num_samples = FLAGS.num_samples
    num_classes = FLAGS.num_classes
    tfrecord_path = FLAGS.tfrecord_path

    dataset = get_record_dataset(tfrecord_path,
                                 num_samples=num_samples,
                                 num_classes=num_classes)
    # The provider reads data according to the dataset description.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    [image, label, path, height, width] = data_provider.get(
        ['image', 'label', 'path', 'height', 'width'])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # The coordinator manages the queue-runner threads that feed the
        # input pipeline; start_queue_runners launches those threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for i in range(num_samples):
                img, lbl, p, h, w = sess.run(
                    [image, label, path, height, width])
                # Use the size stored in the record instead of a fixed
                # 256x256, which fails for any other resize_size.
                img = np.reshape(img, [h, w, 3])
                img = np.array(img, dtype=np.uint8)
                # The decoded image is RGB; OpenCV displays BGR.
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                print('label: {}, image path: {}'.format(
                    lbl, p.decode('utf8')))
                cv2.imshow('img', img)
                if cv2.waitKey(1) == 27:  # 27 is the Esc key code.
                    break
                if i % 100 == 0:
                    print('{} example checked.'.format(i))
        finally:
            # Always stop and join the reader threads, even after an error.
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    # Entry point when executed as a script: tf.app.run() parses the
    # command-line flags and then calls main().
    tf.app.run()
上一篇下一篇

猜你喜欢

热点阅读