Big Data, Python Crawlers, AI, SQL | Python小哥哥

A Python-based MySQL replication tool!

2019-06-26  14e61d025165

TensorFlow Notes: TFRecords

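There are currently two main ways to feed data into TensorFlow: directly from Python, or through TFRecords. Image projects more often implement their own data provider, while NLP projects more often generate TFRecords first and then read them with tf.data.TFRecordDataset. I prefer the TFRecords approach, mainly because it keeps data cleaning and model processing separate instead of mixing them together. TFRecords act as an intermediate format, and what goes into them depends entirely on how you understand the problem, because this is where you define the features you will use.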


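We usually start from raw data such as images or text. Take image classification or text classification as an example: the input features might be an image's pixel matrix or the IDs of the words in a text, while the label might be the ID of a class or even the class name as a string. TensorFlow abstracts such a data point as an Example. An Example contains multiple Features, and each Feature holds one of three data types. A TFRecord file stores the binary data of Example objects; more precisely, the binary data produced by serializing them with protobuf. When reading, TensorFlow's Dataset API decodes the serialized data and turns the features you want into the corresponding Tensors. The basic abstraction and workflow are shown below.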

[Figure: the Example abstraction and the TFRecords write/read workflow]

Example

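Creating TFRecords requires a good understanding of how Example is defined. Data types are abstracted into three kinds: bytes, float, and int64. The basic building block of a Feature is a list of one of these three types, defined as follows: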

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}
</pre>
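To make the three list types concrete, here is a minimal Python sketch, assuming TensorFlow is installed (the tf.train protobuf classes exist in both 1.x and 2.x). The values are made up for illustration. Two details follow from the definitions: BytesList takes bytes, so text is encoded first, and FloatList stores 32-bit floats, so a value like 0.1 comes back slightly different after serialization.

```python
import tensorflow as tf

# Each wrapper holds a repeated field of exactly one primitive type.
bytes_list = tf.train.BytesList(value=[b"cat", "dog".encode("utf-8")])  # bytes only: encode str first
float_list = tf.train.FloatList(value=[0.1, 2.5])  # stored as 32-bit floats
int64_list = tf.train.Int64List(value=[3, 1, 4])

# Round-trip through the protobuf wire format: floats come back at float32 precision.
restored = tf.train.FloatList.FromString(float_list.SerializeToString())
print(list(restored.value))  # [0.10000000149011612, 2.5]
```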

A Feature is simply a wrapper around one of BytesList, FloatList, or Int64List:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
</pre>
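The oneof kind means a Feature carries exactly one of the three lists. A short sketch, assuming TensorFlow's tf.train.Feature class; the helper describe and the sample values are hypothetical. The standard protobuf method WhichOneof reports which member is set, or None for an empty Feature:

```python
import tensorflow as tf


def describe(feature):
    # Name of the oneof member that is set ("bytes_list", "float_list", "int64_list"), or None.
    return feature.WhichOneof("kind")


image = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\xff\xd8"]))
score = tf.train.Feature(float_list=tf.train.FloatList(value=[0.5]))
label = tf.train.Feature(int64_list=tf.train.Int64List(value=[7]))
empty = tf.train.Feature()

for name, f in [("image", image), ("score", score), ("label", label), ("empty", empty)]:
    print(name, describe(f))
```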

Feature objects can be grouped into a name-keyed map called Features:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};
</pre>

They can also be grouped into a list called FeatureList:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message FeatureList {
  repeated Feature feature = 1;
};
</pre>

Combining the two gives the following type, a map from name to FeatureList:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message FeatureLists {
  // Map from feature name to feature list.
  map<string, FeatureList> feature_list = 1;
};
</pre>

If you are familiar with protobuf syntax, these definitions are self-explanatory.

An Example is a map-style collection of Features:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message Example {
  Features features = 1;
};
</pre>
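Putting the pieces together, the sketch below (hypothetical feature names and values, assuming TensorFlow is installed) builds an Example from a Features map, serializes it to the bytes a TFRecord file would hold, and parses those bytes back with the protobuf FromString class method.

```python
import tensorflow as tf

# Hypothetical toy record: placeholder image bytes and a one-hot label.
features = tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"fake-jpeg-bytes"])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 0, 1])),
})
example = tf.train.Example(features=features)

# These protobuf wire-format bytes are what a TFRecord file stores per record.
serialized = example.SerializeToString()

# Parsing the bytes back yields an equal message; features are looked up by name.
restored = tf.train.Example.FromString(serialized)
label = list(restored.features.feature["label"].int64_list.value)
print(sorted(restored.features.feature.keys()), label)  # ['image', 'label'] [0, 0, 1]
```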

The sequence-style Example, SequenceExample:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
};
</pre>
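As an illustration of SequenceExample, here is a sketch for a hypothetical tokenized sentence: per-sequence data (its length) goes into context, and per-step data (one Feature per token ID) goes into a FeatureList under feature_lists. The names length and tokens are assumptions, not a required schema; TensorFlow is assumed to be installed.

```python
import tensorflow as tf

token_ids = [12, 7, 256, 3]  # hypothetical word IDs for one sentence

# Context: features describing the whole sequence.
context = tf.train.Features(feature={
    "length": tf.train.Feature(int64_list=tf.train.Int64List(value=[len(token_ids)])),
})
# FeatureList: one Feature per time step.
tokens = tf.train.FeatureList(feature=[
    tf.train.Feature(int64_list=tf.train.Int64List(value=[t])) for t in token_ids
])
seq_example = tf.train.SequenceExample(
    context=context,
    feature_lists=tf.train.FeatureLists(feature_list={"tokens": tokens}),
)

restored = tf.train.SequenceExample.FromString(seq_example.SerializeToString())
steps = restored.feature_lists.feature_list["tokens"].feature
print([f.int64_list.value[0] for f in steps])  # [12, 7, 256, 3]
```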

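With these definitions in mind, all that remains is to convert the raw data into bytes, float, or int64 values, wrap them in Features, assemble an Example, and serialize it to a file. Below is a complete example that serializes the MNIST data into TFRecords: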

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">#!/usr/bin/env python

# -*- coding:utf-8 -*-
# author: wu.zheng midday.me
# Written against TensorFlow 1.x (tf.python_io); TF 2.x offers tf.io.TFRecordWriter.
import os
import sys

import cv2
import mnist  # the python-mnist package
import numpy as np
import tensorflow as tf


def _bytes_feature(value):
    """Wrap a bytes value (or a list of bytes) in a Feature."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _int64_feature(value):
    """Wrap an int (or a list of ints) in a Feature."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def write_tfrecord(data, labels, out_data_path):
    writer = tf.python_io.TFRecordWriter(out_data_path)
    counter = 0
    total_count = len(data)
    for image, label in zip(data, labels):
        counter += 1
        # MNIST pixels are 0-255; cv2.imencode needs a uint8 array.
        image = np.array(image, dtype=np.uint8).reshape((28, 28))
        is_success, image_buffer = cv2.imencode(".jpg", image)
        if not is_success:
            continue
        # One-hot encode the digit label (10 classes).
        label_value = [0] * 10
        label_value[label] = 1
        image_feature = _bytes_feature(image_buffer.tobytes())
        label_feature = _int64_feature(label_value)
        features = tf.train.Features(feature={"image": image_feature, "label": label_feature})
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())
        sys.stdout.write("\r>>Writing to {:s} {:d}/{:d}".format(out_data_path, counter, total_count))
        sys.stdout.flush()
    writer.close()
    sys.stdout.write("\n")
    sys.stdout.write(">>{:s} write finish.\n".format(out_data_path))


def create_mnist_tfrecord(in_data_folder, out_data_folder):
    meta_data = mnist.MNIST(in_data_folder)
    train_data, train_labels = meta_data.load_training()
    test_data, test_labels = meta_data.load_testing()
    # The writer expects the output directory to exist.
    if not os.path.exists(out_data_folder):
        os.makedirs(out_data_folder)
    train_tf_record_path = os.path.join(out_data_folder, 'train_mnist.tfrecord')
    test_tf_record_path = os.path.join(out_data_folder, 'test_mnist.tfrecord')
    write_tfrecord(train_data, train_labels, train_tf_record_path)
    write_tfrecord(test_data, test_labels, test_tf_record_path)


if __name__ == "__main__":
    # ./datasets/mnist holds the decompressed MNIST files.
    in_data_folder = "./datasets/mnist"
    out_data_folder = "./datasets/mnist_tfrecord"
    create_mnist_tfrecord(in_data_folder, out_data_folder)
</pre>

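The example above still leaves room for improvement. Building TFRecords is usually fairly slow (a few million samples can take a day or two), so the work should be parallelized across multiple threads. A single TFRecords file can also grow very large and is awkward for distributed training, so the output is usually split into many shard files.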
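One possible way to address both points, sketched in plain Python under stated assumptions: split the data into contiguous shards, give each shard its own output file, and write the shards concurrently with concurrent.futures.ThreadPoolExecutor. write_fn is expected to have the same signature as write_tfrecord(data, labels, out_data_path) above; the helper names and the prefix-00000-of-00008.tfrecord naming scheme are hypothetical conventions. Because each shard gets its own writer, the threads share no writer state.

```python
import os
from concurrent.futures import ThreadPoolExecutor


def shard_ranges(num_items, num_shards):
    """Split [0, num_items) into num_shards contiguous, nearly equal (start, end) ranges."""
    base, extra = divmod(num_items, num_shards)
    ranges, start = [], 0
    for i in range(num_shards):
        end = start + base + (1 if i < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


def shard_path(out_dir, prefix, index, num_shards):
    # Hypothetical naming scheme, e.g. train_mnist-00000-of-00008.tfrecord
    return os.path.join(out_dir, "{}-{:05d}-of-{:05d}.tfrecord".format(prefix, index, num_shards))


def write_sharded(data, labels, out_dir, prefix, num_shards, write_fn, max_workers=4):
    """Call write_fn(data_slice, labels_slice, path) once per shard in worker threads."""
    ranges = shard_ranges(len(data), num_shards)
    paths = [shard_path(out_dir, prefix, i, num_shards) for i in range(num_shards)]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(write_fn, data[s:e], labels[s:e], path)
            for (s, e), path in zip(ranges, paths)
        ]
        for future in futures:
            future.result()  # re-raises any exception from a worker
    return paths
```

For example, write_sharded(train_data, train_labels, "./datasets/mnist_tfrecord", "train_mnist", 8, write_tfrecord). Threads only pay off if the heavy calls release the GIL; if the per-record work is dominated by pure Python, ProcessPoolExecutor (with a module-level write_fn) may fit better. The progress lines printed by write_tfrecord will interleave across shards. The resulting list of paths can be passed directly to tf.data.TFRecordDataset, which accepts multiple filenames.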

input_fn

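Once the TFRecords exist, all we need is an input_fn. If new data arrives later and we want to keep training the model, we only need to convert it into TFRecords with the steps above; the input_fn does not have to change. Data augmentation and other preprocessing can also be done inside input_fn.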

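One slightly tedious point is that each Feature defined in the Example needs a matching parsing spec when it is read back: VarLenFeature, SparseFeature, FixedLenFeature, or FixedLenSequenceFeature (tf.VarLenFeature and so on in TF 1.x, under tf.io in newer versions). Pick whichever fits each feature; in essence they correspond to Tensors of different forms, e.g. VarLenFeature produces a SparseTensor.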
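The difference between the parsing specs is easiest to see on a single record. The sketch below assumes TensorFlow 2.x with eager execution (in 1.x graph mode the results would have to be evaluated in a tf.Session). It parses the same hypothetical three-element one-hot label twice: VarLenFeature returns a SparseTensor, while FixedLenFeature with a declared shape returns a dense Tensor.

```python
import tensorflow as tf  # assumes TF 2.x eager execution

example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 0, 1])),
}))
serialized = example.SerializeToString()

# VarLenFeature: length may vary per record, so the result is a SparseTensor.
sparse = tf.io.parse_single_example(serialized, {"label": tf.io.VarLenFeature(tf.int64)})["label"]
# FixedLenFeature: the caller declares the shape, so the result is a dense Tensor.
dense = tf.io.parse_single_example(serialized, {"label": tf.io.FixedLenFeature([3], tf.int64)})["label"]

print(tf.sparse.to_dense(sparse).numpy().tolist())  # [0, 0, 1]
print(dense.numpy().tolist())  # [0, 0, 1]
```

For the MNIST writer above, which stores a 10-element one-hot label, tf.io.FixedLenFeature([10], tf.int64) would yield a dense label directly; tf.sparse.to_dense converts the sparse form when needed.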

Below is an input_fn implementation for the MNIST data:

<pre spellcheck="false" style="box-sizing: border-box; margin: 5px 0px; padding: 5px 10px; border: 0px; font-style: normal; font-variant-ligatures: normal; font-variant-caps: normal; font-variant-numeric: inherit; font-variant-east-asian: inherit; font-weight: 400; font-stretch: inherit; font-size: 16px; line-height: inherit; font-family: inherit; vertical-align: baseline; cursor: text; counter-reset: list-1 0 list-2 0 list-3 0 list-4 0 list-5 0 list-6 0 list-7 0 list-8 0 list-9 0; background-color: rgb(240, 240, 240); border-radius: 3px; white-space: pre-wrap; color: rgb(34, 34, 34); letter-spacing: normal; orphans: 2; text-align: left; text-indent: 0px; text-transform: none; widows: 2; word-spacing: 0px; -webkit-text-stroke-width: 0px; text-decoration-style: initial; text-decoration-color: initial;">#!/usr/bin/env python

# -*- coding:utf-8 -*-
# author: wu.zheng midday.me
# Written against TensorFlow 1.x (tf.parse_single_example, tf.Session).
import tensorflow as tf


def _decode_record(record_proto):
    # The parsing spec must match the features written into the TFRecords.
    feature_map = {
        "image": tf.FixedLenFeature((), tf.string),
        # VarLenFeature yields a SparseTensor; tf.FixedLenFeature([10], tf.int64)
        # would give a dense one-hot vector instead.
        'label': tf.VarLenFeature(tf.int64),
    }
    features = tf.parse_single_example(record_proto, features=feature_map)
    image = tf.image.decode_jpeg(features['image'], channels=1)  # [28, 28, 1] uint8
    image = tf.cast(image, tf.float32)
    # Zero-pad height and width from 28 to 32; leave the channel axis alone.
    paddings = tf.constant([[2, 2], [2, 2], [0, 0]])
    image = tf.pad(image, paddings, mode='CONSTANT', constant_values=0)
    image = image / 255.0  # scale pixels to [0, 1]
    label = features['label']
    return {"image": image, "label": label}


def input_fn(tfrecord_path, batch_size, is_training):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    if is_training:
        # Loop over the data indefinitely, shuffling within a 10000-record buffer.
        dataset = dataset.repeat().shuffle(buffer_size=10000)
    else:
        # A single pass over the evaluation data.
        dataset = dataset.repeat(1)
    dataset = dataset.map(_decode_record)
    dataset = dataset.batch(batch_size=batch_size)
    return dataset.make_one_shot_iterator()


if __name__ == "__main__":
    tf_record_path = "./datasets/mnist_tfrecord/train_mnist.tfrecord"
    with tf.Session() as sess:
        iterator = input_fn(tf_record_path, 1, True)
        next_batch = iterator.get_next()
        # Fetch one batch and check its shape: (1, 32, 32, 1) for batch_size=1.
        batch = sess.run(next_batch)
        print(batch['image'].shape)
</pre>
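Since data augmentation can live inside input_fn, here is one hypothetical example that is not part of the original code: a map function that randomly translates the image by zero-padding it and randomly cropping back to the original size. It uses ops available in both TF 1.x and 2.x and assumes the {"image", "label"} dict produced by _decode_record; random_shift and max_shift are illustrative names, and the direct call in the check below assumes TF 2.x eager execution.

```python
import tensorflow as tf


def random_shift(example, max_shift=2):
    """Randomly translate an [H, W, C] image by up to max_shift pixels (hypothetical augmentation)."""
    image = example["image"]
    shape = tf.shape(image)
    # Pad zeros on every spatial side, then crop a window of the original size at a random offset.
    padded = tf.pad(image, [[max_shift, max_shift], [max_shift, max_shift], [0, 0]])
    shifted = tf.image.random_crop(padded, size=shape)
    return {"image": shifted, "label": example["label"]}
```

It would be chained only in the training branch, e.g. dataset = dataset.map(_decode_record).map(random_shift), before batch. Whether a 2-pixel shift suits a given model is an assumption worth validating.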
