1. Tensorflow实战学习:TFRecord读取数据
2017-11-30 本文已影响0人
闪电侠悟空
数据读取的基本方式参见CS 20SI 的Input Pipeline
部分,Tensorflow主要有两种加载数据的方式:
- Feeding:给出placeholder,然后在session中传递参数的方式输入数据。
- Reading from files: 不显示的利用用占位符,直接利用从文件读取生成队列,然后利用tf.cast函数直接将数据丢入到tensorflow的Graph中。
超大数据文件的主流读取的方式是第二种,并且最好是使用Tensorflow自带的TFRecord文件格式,TFRecord使用方法也比较简单,(要很好的理解其中的队列和多线程的原理,请看CS20SI课程提供的PPT),下面是实现的网页参考:
tf.train.shuffle_batch的使用说明
tf.train.batch和tf.train.shuffle_batch的理解:详细解释了这两个batch函数使用的不同。min_after_dequeue值越大,数据越乱,为了效率,个人认为保持capacity值的1/2到3/4就足够乱了。
生成和读取实验
输入图像及亮点比例标签下面是具体的实现代码:
'''
TFRecord Study
数据的写入,与数据的读取
Author: 闪电侠悟空
Date: 2017-11-30
'''
from time import sleep
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
TRAIN_NUM = 10000
def write2tfrecord():
'''
Write to TFRecord files
'''
# Step 1. construct the TFRecord Writer
writer = tf.python_io.TFRecordWriter(path='IMBD.tfrecords')
for (threshold,i) in zip(np.linspace(0,1,TRAIN_NUM,dtype=np.float32),range(TRAIN_NUM)):
print(threshold,'and ', i,'is saving!')
prob = np.random.uniform(0,1,[64,64]) # construct the data set
image = np.uint8(prob<threshold)*255
print(type(image[1,1]))
# Step 2. to bytes
image_raw = image.tostring()
# Step 3. construct the example
y = tf.train.Feature(float_list=tf.train.FloatList(value=[threshold]))
x = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))
z = tf.train.Feature(int64_list=tf.train.Int64List(value = [i]))
example = tf.train.Example(features=tf.train.Features(feature = {"percent":y,"number":z,"raw_image":x }))
# Step 4. write the example to the file
writer.write(example.SerializeToString())
pass
#Step 5. close the writer
writer.close()
def readanddecode():
filename_queue = tf.train.string_input_producer(['IMBD.tfrecords'])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue) # Reture the file name and content
features = tf.parse_single_example(serialized_example,features={"percent":tf.FixedLenFeature([],tf.float32),\
"number":tf.FixedLenFeature([],tf.int64),\
"raw_image":tf.FixedLenFeature([],tf.string)})
img = tf.decode_raw(features["raw_image"],tf.uint8)
img = tf.reshape(img,[64,64])
img = tf.cast(img,tf.uint8)
i = tf.cast(features["number"],tf.int64)
percent = tf.cast(features["percent"],tf.float32)
return img,i,percent
def mainloop():
img, i, percent = readanddecode()#get a single data
img_batch,i_batch = tf.train.shuffle_batch([img,i],batch_size=20,capacity=10000,min_after_dequeue=9999)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)
for j in range(502):
val_images, val_is = sess.run([img_batch,i_batch])
print(val_is)
if __name__ =="__main__":
#write2tfrecord()
#readanddecode()
mainloop()