tensorflow技术解析与实战——阅读笔记

tensorflow初探八之加载数据-Tensorflow技术解

2019-01-29  本文已影响0人  欠我的都给我吐出来

加载数据

TensorFlow 作为符号编程框架,需要先构建数据流图,再读取数据,随后进行模型训练。

import tensorflow as tf
# 第二种方式:填充数据
a1 = tf.placeholder(tf.int16)
a2 = tf.placeholder(tf.int16)
b = tf.add(x1, x2)
# 用 Python 产生数据
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打开一个会话,将数据填充给后端
with tf.Session() as sess:
  print sess.run(b, feed_dict={a1: li1, a2: li2})

TFRecords 是一种二进制文件,能更好地利用内存,更方便地复制和移动,并且不需要单独的标记文件。

从文件读取数据分为如下两个步骤:
(1)把样本数据写入 TFRecords 二进制文件;
(2)再从队列中读取。

把样本数据写入 TFRecords 二进制文件

  1. 将数据填入到 tf.train.Example 的协议缓冲区(protocolbuffer)中
 example=tf.train.Example(features=tf.train.Features(feature={
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'depth': _int64_feature(depth),
                    'label': _int64_feature(int(labels[i].tolist)),
                    'image_raw': _bytes_feature(image_raw)
                }))
  1. 将协议缓冲区序列化为一个字符串,通过 tf.python_io.TFRecordWriter 写入 TFRecords文件
#定义一个writer
    filename=os.path.join(os.getcwd(),name+'.tfrecords')
    writer= tf.python_io.TFRecordWriter(filename)
......
#对于for i in range(num_example)中的每个example,写入文件
    writer.write(example.SerializerToString())
  1. 最后关闭writer
writer.close()

从队列中读取
一旦生成了 TFRecords 文件,接下来就可以使用队列读取数据了。主要分为 3 步:
(1)创建张量,从二进制文件读取一个样本;
(2)创建张量,从二进制文件随机读取一个 mini-batch;
(3)把每一批张量传入网络作为输入节点。

上一篇下一篇

猜你喜欢

热点阅读