28-二进制文件读取分析

2019-10-05  本文已影响0人  jxvl假装

二进制文件的读取是按照样本的bytes读取

api

"""
tf.FixedLengthRecordReader(record_bytes)
    要读取每个记录是固定数量字节的二进制文件
    record_bytes:整型,指定每次读取的字节数
    return:读取器实例
"""

案例

注意:cifar数据集已经事先下载好

import tensorflow as tf

# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("cifar_dir", "cifar-10-batches-py", "文件的目录")

class CifarRead():
    """
    完成读取二进制文件,写进tfrecords,读取tfrecords
    """

    def __init__(self, filelist):
        self.file_list = filelist   #文件列表
        #定义读取图片的一些属性
        self.height = 32
        self.width = 32
        self.channel = 3
        #存储的字节
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self):
        #构造文件队列
        file_queue = tf.train.string_input_producer(self.file_list)
        #构造二进制文件读取器
        reader = tf.FixedLengthRecordReader(self.bytes)
        key, value = reader.read(file_queue)
        #解码内容
        print(value)
        #二进制文件的解码
        label_image = tf.decode_raw(value, out_type=tf.uint8)
        print(label_image)
        #分割图片和标签:特征值和目标值
        label = tf.slice(label_image, [0], [self.label_bytes])
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        print("label:", label)
        print("image:", image)
        #对图片的特征数据进行形状的改变 [3072] --> [32, 32, 3]
        image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
        print("image_reshape:", image_reshape)

        #批处理数据
        image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
        print(image_batch, label_batch)
        return image_batch, label_batch


import os
if __name__ == "__main__":
    #找到文件,放入列表 路径+名字 ->列表当中
    file_name = os.listdir(FLAGS.cifar_dir)
    file_list = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if "0" <= file[-1] <= "9"]
    print(file_list)
    cf = CifarRead(file_list)
    image_batch, label_batch = cf.read_and_decode()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)
        print(sess.run([image_batch, label_batch]))
        coord.request_stop()
        coord.join()
上一篇 下一篇

猜你喜欢

热点阅读