What the Heck Is TFRecord

2019-12-16  林桉

What is TFRecord?

TFRecord is a data format officially recommended by Google and designed specifically for TensorFlow.

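In practice, a TFRecord file is a binary file, which lets it make better use of memory. It holds a sequence of serialized tf.train.Example records, and Example is a protocol buffer (protobuf) message. Each Example carries a set of tf.train.Feature entries stored as key-value pairs: the key is a string, and the value takes one of three forms:

bytes_list: stores string and byte data.
float_list: stores float (float32) and double (float64) data. Values are kept as float32, so float64 inputs lose precision.
int64_list: stores bool, enum, int32, uint32, int64, and uint64 data.
It is worth noting that .proto files appear throughout the TensorFlow source code. They define many of TensorFlow's core data structures, and because protobuf is language-neutral, several languages can use these structures directly.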
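
To make the "binary file holding many Example records" description concrete, here is a minimal pure-Python sketch of how records are framed on disk. It assumes the layout given in TensorFlow's TFRecord documentation: a little-endian uint64 length, a masked CRC-32C of the length, the payload bytes, then a masked CRC-32C of the payload. The helper names `frame_record` and `iter_records` are made up for illustration and are not TensorFlow APIs; real code should go through tf.io.TFRecordWriter and tf.data.TFRecordDataset.

```python
import struct


def _make_crc32c_table():
    # Table for the reflected CRC-32C (Castagnoli) polynomial.
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    """Plain CRC-32C checksum of a bytes object."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """CRC-32C rotated right by 15 bits plus a constant, as TFRecord stores it."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload):
    """Wrap one serialized record (for example, an Example) in TFRecord framing."""
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))


def iter_records(blob):
    """Yield each payload from a bytes object made of framed records."""
    offset = 0
    while offset < len(blob):
        header = blob[offset:offset + 8]
        (length,) = struct.unpack("<Q", header)
        (header_crc,) = struct.unpack("<I", blob[offset + 8:offset + 12])
        if header_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        start = offset + 12
        payload = blob[start:start + length]
        (payload_crc,) = struct.unpack("<I", blob[start + length:start + length + 4])
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload
        offset = start + length + 4


blob = frame_record(b"first record") + frame_record(b"second record")
print(list(iter_records(blob)))
```

Each record costs 16 bytes of framing on top of its payload, and a reader can skip from record to record using only the length fields, which is why the format suits sequential streaming.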

What the Heck Is protobuf

Advantages:

Installation: https://blog.csdn.net/xxjuanq_only_one/article/details/50465272

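syntax = "proto2";          // required/optional labels are proto2 syntax

import "Common.proto";      // import Common.proto, located in the Protobuf SDK

option optimize_for = LITE_RUNTIME;

option java_package = "com.xxxx.entity.pb";      // package of the generated class
option java_outer_classname = "PayInfoProto";    // outer class name; must differ from the message name

message PayInfo {
    required string payid = 1;             // payment-related fields
    optional string goodinfo = 2;          // optional: the field may be omitted
    required string prepayid = 3;          // required: the field must be set
    optional string mode = 4;
    optional int32 userid = 5;             // protobuf has no plain int type, so int32 is used
    repeated string extra = 6;             // repeated: an array of values
}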

protoc --java_out ./ ./PayInfo.proto
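
To show why protobuf messages are compact, here is a hand-rolled Python sketch of how a few PayInfo fields would be laid out on the wire. It follows the documented protobuf encoding (each field is a varint tag of `(field_number << 3) | wire_type` followed by the value) and treats userid as an int32-style varint field, which is an assumption. The function names and sample values are invented for illustration; generated protoc classes would normally do this work.

```python
def encode_varint(value):
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        low7 = value & 0x7F
        value >>= 7
        if value:
            out.append(low7 | 0x80)  # high bit set: more bytes follow
        else:
            out.append(low7)
            return bytes(out)


def encode_tag(field_number, wire_type):
    """Field header: field number and wire type packed into one varint."""
    return encode_varint((field_number << 3) | wire_type)


def encode_string_field(field_number, text):
    """Wire type 2 (length-delimited): tag, byte length, UTF-8 bytes."""
    data = text.encode("utf-8")
    return encode_tag(field_number, 2) + encode_varint(len(data)) + data


def encode_int32_field(field_number, value):
    """Wire type 0 (varint). This sketch only handles non-negative values."""
    return encode_tag(field_number, 0) + encode_varint(value)


# Hypothetical field values; field numbers come from the PayInfo message.
message = (
    encode_string_field(1, "p1")        # payid
    + encode_string_field(3, "pre9")    # prepayid
    + encode_int32_field(5, 150)        # userid
)
print(message.hex(), len(message))
```

Field names never appear in the output, only field numbers, and small integers take one or two bytes. The three fields here occupy 13 bytes in total.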

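Why Use TFRecord?

TFRecord is not the only data format TensorFlow supports; CSV and plain text work too. For TensorFlow, though, TFRecord is the friendliest and most convenient option. As noted earlier, a TFRecord file is a series of Example messages that follow the protocol buffer standard. For large datasets, a compact binary serialization that can be read sequentially has a clear advantage over text formats that must be parsed field by field.

Converting to TFRecord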

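import tensorflow as tf

# out_file_name and dataes are assumed to be defined by the caller:
# dataes is a sequence of (context, question, answer) lists of integer ids.

# 1. Create the writer (tf.io.TFRecordWriter replaces the older tf.python_io alias).
with tf.io.TFRecordWriter(out_file_name) as writer:
    for data in dataes:
        context = data[0]
        question = data[1]
        answer = data[2]

        # 2. Build the Example from its features.
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'context': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=context)),
                    'question': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=question)),
                    'answer': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=answer))
                }))

        # 3. Serialize the Example and append it to the file.
        writer.write(example.SerializeToString())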

Reading API

https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset
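
As a rough sketch of how tf.data.TFRecordDataset reads variable-length integer lists like the ones written in the conversion step, the snippet below writes two small records and parses them back. It assumes TensorFlow 2.x with eager execution enabled; the sample data, the file name, and the choice of tf.io.VarLenFeature are illustrative assumptions rather than part of the original pipeline.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical sample data: each record is one variable-length list of ids.
samples = [[1, 2, 3], [4, 5]]
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")

# Write one Example per sample.
with tf.io.TFRecordWriter(path) as writer:
    for ids in samples:
        example = tf.train.Example(features=tf.train.Features(feature={
            "context": tf.train.Feature(
                int64_list=tf.train.Int64List(value=ids)),
        }))
        writer.write(example.SerializeToString())

# VarLenFeature is used because the lists differ in length.
feature_spec = {"context": tf.io.VarLenFeature(tf.int64)}


def parse_record(record):
    # Parsing yields a SparseTensor; convert it to a dense 1-D tensor.
    parsed = tf.io.parse_single_example(record, feature_spec)
    return tf.sparse.to_dense(parsed["context"])


dataset = tf.data.TFRecordDataset([path]).map(parse_record)
decoded = [row.numpy().tolist() for row in dataset]
print(decoded)
```

Because the rows have different lengths, calling batch on this dataset would fail; padded_batch is the usual way to batch such records.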

An example

from __future__ import absolute_import, division, print_function

import csv
import requests
import tensorflow as tf
# Download Titanic dataset (in csv format).
d = requests.get("https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/titanic_dataset.csv")
with open("titanic_dataset.csv", "wb") as f:
    f.write(d.content)
# Generate Integer Features.
def build_int64_feature(data):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[data]))

# Generate Float Features.
def build_float_feature(data):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[data]))

# Generate String Features.
def build_string_feature(data):
    # BytesList only accepts bytes, so encode Python 3 str values as UTF-8.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.encode("utf-8")]))

# Generate a TF `Example`, parsing all features of the dataset.
def convert_to_tfexample(survived, pclass, name, sex, age, sibsp, parch, ticket, fare):
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'survived': build_int64_feature(survived),
                'pclass': build_int64_feature(pclass),
                'name': build_string_feature(name),
                'sex': build_string_feature(sex),
                'age': build_float_feature(age),
                'sibsp': build_int64_feature(sibsp),
                'parch': build_int64_feature(parch),
                'ticket': build_string_feature(ticket),
                'fare': build_float_feature(fare),
            })
    )

# Open dataset file.
with open("titanic_dataset.csv") as f:
    # Output TFRecord file.
    with tf.io.TFRecordWriter("titanic_dataset.tfrecord") as w:
        # Generate a TF Example for all row in our dataset.
        # CSV reader will read and parse all rows.
        reader = csv.reader(f, skipinitialspace=True)
        for i, record in enumerate(reader):
            # Skip header.
            if i == 0:
                continue
            survived, pclass, name, sex, age, sibsp, parch, ticket, fare = record
            # Parse each csv row to TF Example using the above functions.
            example = convert_to_tfexample(int(survived), int(pclass), name, sex, float(age), int(sibsp), int(parch), ticket, float(fare))
            # Serialize each TF Example to string, and write to TFRecord file.
            w.write(example.SerializeToString())
# Build features template, with types.
features = {
    'survived': tf.io.FixedLenFeature([], tf.int64),
    'pclass': tf.io.FixedLenFeature([], tf.int64),
    'name': tf.io.FixedLenFeature([], tf.string),
    'sex': tf.io.FixedLenFeature([], tf.string),
    'age': tf.io.FixedLenFeature([], tf.float32),
    'sibsp': tf.io.FixedLenFeature([], tf.int64),
    'parch': tf.io.FixedLenFeature([], tf.int64),
    'ticket': tf.io.FixedLenFeature([], tf.string),
    'fare': tf.io.FixedLenFeature([], tf.float32),
}

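# Create TensorFlow session.
# Note: tf.Session and make_initializable_iterator (used below) are TF 1.x
# graph-mode APIs; under TF 2.x they are only available through tf.compat.v1.
sess = tf.Session()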

# Load TFRecord data.
filenames = ["titanic_dataset.tfrecord"]
data = tf.data.TFRecordDataset(filenames)

# Parse features, using the above template.
def parse_record(record):
    return tf.io.parse_single_example(record, features=features)
# Apply the parsing to each record from the dataset.
data = data.map(parse_record)

# Refill data indefinitely.
data = data.repeat()
# Shuffle data.
data = data.shuffle(buffer_size=1000)
# Batch data (aggregate records together).
data = data.batch(batch_size=4)
# Prefetch batch (pre-load batch for faster consumption).
data = data.prefetch(buffer_size=1)

# Create an iterator over the dataset.
iterator = data.make_initializable_iterator()
# Initialize the iterator.
sess.run(iterator.initializer)

# Get next data batch.
x = iterator.get_next()

# Dequeue data and display.
for i in range(3):
    print(sess.run(x))
    print("")