29-Tfrecords文件的读取与存储

2019-10-05  本文已影响0人  jxvl假装

tfrecords是tensorflow自带的文件格式,也是一种二进制文件:

  1. 方便读取和移动
  2. 是为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中
  3. 文件格式:*.tfrecords
  4. 写如文件的内容:Example协议块,是一种类字典的格式

TFRecords存储的api

"""api
1、建立TFRecord存储器
tf.python_io.TFRecordWriter(path)
    写入tfrecords文件
    path: TFRecords文件的路径
    return:写文件
method
    write(record):向文件中写入一个字符串记录
    close():关闭文件写入器
注:字符串为一个序列化的Example,Example.SerializeToString()

2、构造每个样本的Example协议块
tf.train.Example(features=None)
    写入tfrecords文件
    features:tf.train.Features类型的特征实例
    return:example格式协议块

tf.train.Features(feature=None)
    构建每个样本的信息键值对
    feature:字典数据,key为要保存的名字,
    value为tf.train.Feature实例
    return:Features类型

tf.train.Feature(**options)
    **options:例如
    bytes_list=tf.train. BytesList(value=[Bytes])
    int64_list=tf.train. Int64List(value=[Value])
    tf.train. Int64List(value=[Value])
    tf.train. BytesList(value=[Bytes])
    tf.train. FloatList(value=[value])

同文件阅读器流程,中间需要解析过程

解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
    解析一个单一的Example原型
    serialized:标量字符串Tensor,一个序列化的Example
    features:dict字典数据,键为读取的名字,值为FixedLenFeature
    return:一个键值对组成的字典,键为读取的名字

tf.FixedLenFeature(shape,dtype)
    shape:输入数据的形状,一般不指定,为空列表
    dtype:输入数据类型,与存储进文件的类型要一致
    类型只能是float32,int64,string

"""

读取tfrecords的api与流程

"""api
同文件阅读器流程,中间需要解析过程

解析TFRecords的example协议内存块
tf.parse_single_example(serialized,features=None,name=None)
    解析一个单一的Example原型
    serialized:标量字符串Tensor,一个序列化的Example
    features:dict字典数据,键为读取的名字,值为FixedLenFeature
    return:一个键值对组成的字典,键为读取的名字
"""
"""流程
tf.FixedLenFeature(shape,dtype)
    shape:输入数据的形状,一般不指定,为空列表
    dtype:输入数据类型,与存储进文件的类型要一致
    类型只能是float32,int64,string


1、构造TFRecords阅读器
2、解析Example
3、转换格式,bytes解码
"""

tfrecords文件的读取

import tensorflow as tf

# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir", "cifar-10-batches-py", "文件的目录")
tf.app.flags.DEFINE_string("cifar_tfrecords", "./tmp/cifar.tfrecords", "存进tfrecords的文件")

class CifarRead():
    """
    完成读取二进制文件,写进tfrecords,读取tfrecords
    """

    def __init__(self, filelist):
        self.file_list = filelist  # 文件列表
        # 定义读取图片的一些属性
        self.height = 32
        self.width = 32
        self.channel = 3
        # 存储的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self):
        # 构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)
        # 构造二进制文件读取器,并指定读取长度
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)
        # 解码内容
        print(value)
        # 二进制文件的解码
        label_image = tf.decode_raw(value, out_type=tf.uint8)
        print(label_image)
        # 分割图片和标签:特征值和目标值
        label = tf.slice(label_image, [0], [self.label_bytes])  #读取标签
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])  #读取特征向量
        print("label:", label)
        print("image:", image)
        # 对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        print("image_reshape:", image_reshape)

        # 批处理数据
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
        print(image_batch, label_batch)
        return image_batch, label_batch

    def write_to_tfrecords(self, image_batch, label_batch):
        """
        将图片的特征值和目标值存进tfrecords
        :param image_batch: 10张图片的特征值
        :param label_batch: 10张图片的目标值
        :return: None
        """
        #建立一个tfrecords存储器
        writer = tf.python_io.TFRecordWriter(path=FLAGS.cifar_tfrecords)  #注意:tf.python_io.TFRecordWriter已经被tf.io.TFRecordWriter代替
        #循环将所有样本写入文件,每张图片样本都要构造一个example协议
        for i in range(10):
            #取出第i个图片数据的特征值和目标值
            image = image_batch[i].eval().tostring()   #.eval()获取值
            label = label_batch[i].eval()[0]    #因为是一个二维列表,所以必须取[0]
            """注意:eval必须写在session中"""
            #构造一个样本的example
            example = tf.train.Example(features=tf.train.Features(feature={
                "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }))
            #写入单独的样本
            writer.write(example.SerializeToString())   #序列化后再写入文件
        #关闭
        writer.close()

    def read_from_tfrecords(self):
        #构造文件阅读器
        file_queue = tf.train.input_producer([FLAGS.cifar_tfrecords])
        #构造文件阅读器,读取内容example
        reader = tf.TFRecordReader()
        key, value = reader.read(file_queue)  #value也是一个example的序列化
        #由于存储的是example,所以需要对example解析
        features = tf.parse_single_example(value, features={
            "image":tf.FixedLenFeature(shape=[], dtype=tf.string),
            "label":tf.FixedLenFeature(shape=[], dtype=tf.int64)
        })
        print(features["image"], features["label"]) #注意:此时是tensor
        #解码内容,如果读取的string类型,需要解码,如果是int64,float32就不需要解码。因为里面都是bytes,所以需要解码
        image = tf.decode_raw(features["image"], tf.uint8)
        label = tf.cast(features["label"], tf.int32)   #label不需要解码,因为int64实际在存储的时候还是以int32存储的,不会占用那么多空间,所以这里可以直接转换成int32
        print(image, label)
        #固定图片的形状,以方便批处理
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        print(image_reshape)
        #进行批处理
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
        return image_batch, label_batch
import os

if __name__ == "__main__":
    # 找到文件,放入列表 路径+名字 ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)
    file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if "0" <= file[-1] <= "9"]
    print(file_list)
    cf = CifarRead(file_list)
    image_batch, label_batch = cf.read_and_decode()
    # image_batch, label_batch = cf.read_from_tfrecords()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        print(sess.run([image_batch, label_batch]))
        #存进tfrecords文件
        print("开始存储...")
        cf.write_to_tfrecords(image_batch, label_batch) #因为这个函数里面有eval,所以必须在session里面运行
        print("结束存储...")
        # print("读取的数据:\n",sess.run([image_batch, label_batch]))
        coord.request_stop()
        coord.join()

上一篇下一篇

猜你喜欢

热点阅读