tensorflow数据操作

2019-06-26  本文已影响0人  翻开日记
import tensorflow as tf
import numpy as np

def _parse_function(example_proto):
    """Parse one serialized ``tf.train.Example`` record.

    Each record is read with two fixed-length scalar features: ``'images'``
    (a single byte string) and ``'labels'`` (a single int64).

    Args:
        example_proto: scalar string tensor holding one serialized example.

    Returns:
        Tuple ``(data, label)``: ``data`` is the ``'images'`` byte string
        reinterpreted as float32 values via ``tf.decode_raw`` (its length is
        not known statically), ``label`` is the int64 ``'labels'`` value.
    """
    # Both features are scalars (shape ``()``) in the parsing spec.
    spec = {
        'images': tf.FixedLenFeature((), tf.string),
        'labels': tf.FixedLenFeature((), tf.int64),
    }
    example = tf.parse_single_example(example_proto, spec)
    # NOTE(review): assumes the writer serialized the images as raw float32
    # bytes -- confirm against the code that produced the TFRecords.
    image_values = tf.decode_raw(example['images'], tf.float32)
    label = example['labels']
    return image_values, label


def read_one_batch(pattern='data/*', num_epochs=2, batch_size=32):
    """Build a one-shot batched input pipeline and return its next element.

    Despite the name, the returned tensors produce a *new* batch every time
    they are evaluated with ``sess.run``. Once ``num_epochs`` passes over the
    data are used up, evaluation raises ``tf.errors.OutOfRangeError``.

    Args:
        pattern: glob pattern selecting the TFRecord files to read.
        num_epochs: number of times the whole dataset is repeated.
        batch_size: number of records per batch (the final batch may be
            smaller, since ``drop_remainder`` is not set).

    Returns:
        The ``(data, labels)`` tensors from ``_parse_function``, batched.

    Raises:
        ValueError: if ``pattern`` matches no files.
    """
    filenames = tf.gfile.Glob(pattern)
    if not filenames:
        # Fail at graph-construction time with a clear message rather than
        # only discovering the empty input later, at run time.
        raise ValueError('no TFRecord files match %r' % (pattern,))

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    # A one-shot iterator needs no explicit initialization before use.
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

def read_N_batch(pattern='data/*', num_epochs=2, return_handles=False):
    """Build an input pipeline whose batch size is chosen at run time.

    The batch size is a scalar int64 placeholder consumed by ``Dataset.batch``.
    The pipeline uses an initializable iterator, so before the next-element
    tensors can be evaluated the caller must run::

        sess.run(initializer, feed_dict={batch_size: N})

    With the default ``return_handles=False`` only the next-element tensors
    are returned (the original return value). In that case the caller has no
    handle on the placeholder or the initializer and cannot initialize the
    iterator. Pass ``return_handles=True`` to receive them.

    Args:
        pattern: glob pattern selecting the TFRecord files to read.
        num_epochs: number of times the whole dataset is repeated.
        return_handles: if True, also return the batch-size placeholder and
            the iterator's initializer op.

    Returns:
        ``iterator.get_next()`` (the batched ``(data, labels)`` tensors), or,
        when ``return_handles`` is True, the tuple
        ``(next_data, batch_size_placeholder, initializer)``.

    Raises:
        ValueError: if ``pattern`` matches no files.
    """
    filenames = tf.gfile.Glob(pattern)
    if not filenames:
        # Fail at graph-construction time with a clear message rather than
        # only discovering the empty input later, at run time.
        raise ValueError('no TFRecord files match %r' % (pattern,))

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(num_epochs)
    # Fed when the iterator is initialized, so one graph can serve
    # different batch sizes.
    batch = tf.placeholder(tf.int64, shape=[])
    dataset = dataset.batch(batch)

    iterator = dataset.make_initializable_iterator()
    next_data = iterator.get_next()

    if return_handles:
        return next_data, batch, iterator.initializer
    return next_data


def read_diff_batch(train_pattern='data/*', test_pattern='data/*',
                    train_batch_size=32, test_batch_size=16, num_epochs=2):
    """Build train and test pipelines that share one reinitializable iterator.

    Each dataset is parsed with ``_parse_function``, repeated ``num_epochs``
    times and batched with its own batch size. A single iterator created with
    ``tf.data.Iterator.from_structure`` can be switched between the two by
    running the corresponding initializer op. Each such run restarts that
    dataset from its beginning.

    NOTE(review): by default both patterns are ``'data/*'``, so "train" and
    "test" read the same files. Pass a distinct ``test_pattern`` for a real
    held-out split.

    Args:
        train_pattern: glob pattern for the training TFRecord files.
        test_pattern: glob pattern for the test TFRecord files.
        train_batch_size: records per training batch.
        test_batch_size: records per test batch.
        num_epochs: number of times each whole dataset is repeated.

    Returns:
        Tuple ``(train_op, test_op, next_data)``. ``train_op`` and ``test_op``
        (re)initialize the shared iterator on the respective dataset.
        ``next_data`` is the iterator's next-element ``(data, labels)``
        tensors.

    Raises:
        ValueError: if either pattern matches no files.
    """
    def _make_dataset(pattern, batch_size):
        # Parse -> repeat -> batch pipeline over the files matching pattern.
        filenames = tf.gfile.Glob(pattern)
        if not filenames:
            raise ValueError('no TFRecord files match %r' % (pattern,))
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parse_function)
        # Number of passes over the whole dataset.
        dataset = dataset.repeat(num_epochs)
        return dataset.batch(batch_size)

    tr_dataset = _make_dataset(train_pattern, train_batch_size)
    te_dataset = _make_dataset(test_pattern, test_batch_size)

    # The iterator's structure is taken from the training dataset, and the
    # test dataset must be compatible with it. Neither dataset sets
    # drop_remainder, so the batch dimension is unknown and different batch
    # sizes are still compatible.
    iterator = tf.data.Iterator.from_structure(tr_dataset.output_types,
                                               tr_dataset.output_shapes)

    train_op = iterator.make_initializer(tr_dataset)
    test_op = iterator.make_initializer(te_dataset)

    next_data = iterator.get_next()

    return train_op, test_op, next_data


if __name__ == '__main__':
    # Demo: alternate one reinitializable iterator between the train and
    # test pipelines. read_diff_batch must be *called*; unpacking the
    # function object itself raises TypeError.
    train_op, test_op, next_data = read_diff_batch()

    with tf.Session() as sess:
        for _ in range(2):
            # Point the iterator at the training data (restarting it), then
            # print the shape of the data tensor of three batches.
            sess.run(train_op)
            for _ in range(3):
                print(np.shape(sess.run(next_data)[0]))

            # Switch the same iterator to the test data and pull two batches.
            sess.run(test_op)
            for _ in range(2):
                print(np.shape(sess.run(next_data)[0]))
上一篇下一篇

猜你喜欢

热点阅读