TensorFlow in Practice: Reading Data with TFRecord

2017-11-30  闪电侠悟空

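For the basics of data loading, see the Input Pipeline part of CS 20SI. TensorFlow has two main ways to load data:

  1. Feeding: declare a placeholder, then supply the data as an argument (feed_dict) when running the session.
  2. Reading from files: no explicit placeholders are used. A queue is built directly from the files, and the decoded data (converted with ops such as tf.cast) goes straight into the TensorFlow graph.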
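
The second approach can be pictured as a background thread that keeps a bounded queue filled while the training loop dequeues batches from it. The sketch below is only a standard-library analogy of that idea, not TensorFlow code; the names `reader_thread` and `batches` and the `capacity` default are made up for illustration.

```python
import queue
import threading


def reader_thread(records, q, sentinel):
    """Producer: plays the role of the file reader that fills the queue."""
    for record in records:
        q.put(record)   # blocks while the queue is full
    q.put(sentinel)     # signals that the data is exhausted


def batches(records, batch_size, capacity=8):
    """Consumer: yields lists of up to batch_size records taken from the queue."""
    q = queue.Queue(maxsize=capacity)
    sentinel = object()
    worker = threading.Thread(
        target=reader_thread, args=(records, q, sentinel), daemon=True)
    worker.start()
    batch = []
    while True:
        item = q.get()  # blocks until the producer has enqueued something
        if item is sentinel:
            break
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch     # final partial batch
    worker.join()


if __name__ == "__main__":
    for batch in batches(range(10), batch_size=4):
        print(batch)
```

The consumer never sees a placeholder: it simply asks the queue for the next element, which is the same division of labour the queue-based readers provide inside the graph.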

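For very large data files the second approach is the mainstream choice, preferably combined with TensorFlow's own TFRecord file format. TFRecord is fairly simple to use (for a solid understanding of the queues and multithreading behind it, see the slides provided by the CS20SI course). The following pages were used as references for the implementation: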

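Usage notes for tf.train.shuffle_batch

Understanding tf.train.batch and tf.train.shuffle_batch: explains in detail how the two batching functions differ. The larger min_after_dequeue is, the more thoroughly the data is shuffled. For the sake of efficiency, I find that keeping it at 1/2 to 3/4 of capacity shuffles the data well enough.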
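
To get a feel for why a larger min_after_dequeue mixes the data more, here is a toy, single-threaded model of a shuffle queue written with the standard library only. It is not TensorFlow code, and it assumes the consumer always keeps up with the reader, so the buffer sits at its minimum level and `capacity` only acts as an upper bound. The helper names `simulated_shuffle` and `mean_displacement` are hypothetical.

```python
import random


def simulated_shuffle(items, capacity, min_after_dequeue, seed=0):
    """Toy model of a shuffle queue (assumption: the consumer is never the bottleneck)."""
    if not 0 <= min_after_dequeue < capacity:
        raise ValueError("need 0 <= min_after_dequeue < capacity")
    rng = random.Random(seed)
    buffer, out = [], []
    for item in items:
        buffer.append(item)  # enqueue one element from the reader
        # Dequeue a random element only if at least min_after_dequeue
        # elements remain in the buffer afterwards.
        if len(buffer) > min_after_dequeue:
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            out.append(buffer.pop())
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    out.extend(buffer)
    return out


def mean_displacement(order):
    """Average distance between an element's input index and its output position."""
    return sum(abs(pos - value) for pos, value in enumerate(order)) / len(order)


if __name__ == "__main__":
    n, capacity = 2000, 1000
    for min_after_dequeue in (0, 250, 500, 750, 999):
        order = simulated_shuffle(range(n), capacity, min_after_dequeue)
        print(min_after_dequeue, round(mean_displacement(order), 1))
```

With min_after_dequeue=0 the buffer never holds more than one element, so the output order equals the input order; raising the value lets each element wait longer in the buffer before it is drawn.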

Write and read experiment

Input image and its bright-pixel ratio label

The full implementation is below:

'''
TFRecord Study
Writing data and reading it back (TensorFlow 1.x queue-based input API)
Author: 闪电侠悟空
Date: 2017-11-30
'''
from time import sleep

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

TRAIN_NUM = 10000

def write2tfrecord():
    '''
    Write TRAIN_NUM synthetic examples to a TFRecord file
    '''
    # Step 1. construct the TFRecord Writer
    writer = tf.python_io.TFRecordWriter(path='IMBD.tfrecords')

    for i, threshold in enumerate(np.linspace(0, 1, TRAIN_NUM, dtype=np.float32)):
        print(threshold, 'and', i, 'is saving!')
        # Each pixel is bright (255) with probability `threshold`, so the label
        # is the expected fraction of bright pixels in the image.
        prob = np.random.uniform(0, 1, [64, 64])
        image = np.uint8(prob < threshold) * 255  # dtype stays uint8

        # Step 2. to bytes (tobytes() replaces the deprecated tostring())
        image_raw = image.tobytes()

        # Step 3. construct the example
        y = tf.train.Feature(float_list=tf.train.FloatList(value=[threshold]))
        x = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
        z = tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))

        example = tf.train.Example(features=tf.train.Features(
            feature={"percent": y, "number": z, "raw_image": x}))

        # Step 4. write the example to the file
        writer.write(example.SerializeToString())

    # Step 5. close the writer
    writer.close()


def readanddecode():
    # Queue of input file names; with the default num_epochs=None it cycles indefinitely
    filename_queue = tf.train.string_input_producer(['IMBD.tfrecords'])

    reader = tf.TFRecordReader()
    # read() returns a (key, value) pair; only the serialized record is needed here
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={"percent": tf.FixedLenFeature([], tf.float32),
                  "number": tf.FixedLenFeature([], tf.int64),
                  "raw_image": tf.FixedLenFeature([], tf.string)})
    # The raw bytes carry no shape information, so the 64x64 layout is restored by hand
    img = tf.decode_raw(features["raw_image"], tf.uint8)
    img = tf.reshape(img, [64, 64])
    img = tf.cast(img, tf.uint8)  # no-op here: decode_raw already yields uint8

    i = tf.cast(features["number"], tf.int64)
    percent = tf.cast(features["percent"], tf.float32)
    return img, i, percent

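def mainloop():
    img, i, percent = readanddecode()  # tensors for a single example
    # min_after_dequeue=9999 against capacity=10000 mixes thoroughly, but each
    # dequeue then has to wait until the queue is nearly full.
    img_batch, i_batch = tf.train.shuffle_batch([img, i], batch_size=20, capacity=10000, min_after_dequeue=9999)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # The coordinator lets the queue-runner threads be stopped cleanly
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for j in range(502):
            val_images, val_is = sess.run([img_batch, i_batch])
            print(val_is)
        coord.request_stop()
        coord.join(threads)

if __name__ == "__main__":
    # Run write2tfrecord() once to create IMBD.tfrecords before reading
    #write2tfrecord()
    #readanddecode()
    mainloop()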
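
One detail of the code above is easy to miss: the image is stored as a flat byte string, so the reader has to know the 64x64 shape in advance. The following standard-library sketch mimics that round trip without TensorFlow or NumPy, assuming one byte per pixel in row-major order; `make_image`, `to_raw`, `from_raw` and `bright_fraction` are hypothetical helper names used only for illustration.

```python
import random

HEIGHT, WIDTH = 64, 64  # must agree between the writing and the reading side


def make_image(threshold, rng):
    """Rows of 0/255 pixels; each pixel is bright with probability `threshold`."""
    return [[255 if rng.random() < threshold else 0 for _ in range(WIDTH)]
            for _ in range(HEIGHT)]


def to_raw(image):
    """Flatten row by row into one byte per pixel (uint8, row-major)."""
    return bytes(pixel for row in image for pixel in row)


def from_raw(raw, height, width):
    """Inverse of to_raw: the byte string carries no shape, so it is passed in."""
    if len(raw) != height * width:
        raise ValueError("byte count does not match the requested shape")
    return [list(raw[r * width:(r + 1) * width]) for r in range(height)]


def bright_fraction(image):
    """Share of pixels equal to 255, comparable to the 'percent' label."""
    pixels = [pixel for row in image for pixel in row]
    return sum(pixel == 255 for pixel in pixels) / len(pixels)


if __name__ == "__main__":
    rng = random.Random(0)
    image = make_image(0.3, rng)
    raw = to_raw(image)
    print(len(raw))                               # one byte per pixel
    print(from_raw(raw, HEIGHT, WIDTH) == image)  # lossless round trip
    print(bright_fraction(image))                 # near the chosen threshold
```

Passing a wrong shape to `from_raw` raises an error here; in the TensorFlow version the analogous mismatch would surface when tf.reshape runs.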