机器学习之旅

使用Tensorflow的DataSet和Iterator读取数

2018-06-02  本文已影响269人  文哥的学习日记

今天在写NCF代码的时候,发现网络上的代码有一种新的数据读取方式,这里将对应的片段剪出来给大家分享下。

NCF的文章参考:https://www.jianshu.com/p/6173dbde4f53

原始数据
我们的原始数据保存在npy文件中,是一个字典类型,有三个key,分别是user,item和label:

data = np.load('data/test_data.npy').item()
print(type(data))

#output
<class 'dict'>

构建tf的Dataset
使用 tf.data.Dataset.from_tensor_slices方法,将我们的数据变成tensorflow的DataSet:

dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
#output
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>

进一步,将我们的Dataset变成一个BatchDataset,这样的话,在迭代数据的时候,就可以一次返回一个batch大小的数据:

dataset = dataset.shuffle(1000).batch(100)
print(type(dataset))

#output
<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>

可以看到,我们在变成batch之前使用了一个shuffle对数据进行打乱,100表示buffersize,即每取1000个打乱一次。

此时dataset有两个属性,分别是output_shapes和output_types,我们将根据这两个属性来构造迭代器,用于迭代数据。

print(dataset.output_shapes)
print(dataset.output_types)

#output
{'user': TensorShape([Dimension(None)]), 'item': TensorShape([Dimension(None)]), 'label': TensorShape([Dimension(None)])}
{'user': tf.int32, 'item': tf.int32, 'label': tf.int32}

构造迭代器
我们使用上面提到的两个dataset的属性,并使用tf.data.Iterator.from_structure方法来构造一个迭代器:

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

迭代器需要初始化:

 sess.run(iterator.make_initializer(dataset))

此时,就可以使用get_next(),方法来源源不断的读取batch大小的数据了

def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item

使用迭代器的正确姿势
我们这里来计算返回的每个batch中,user和item的平均值:

users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

迭代器iterator只能往前遍历,如果遍历完之后还调用get_next()的话,会报tf.errors.OutOfRangeError错误,因此需要使用try-catch。

try:
    while True:
        print(sess.run([usersum,itemsum]))
except tf.errors.OutOfRangeError:
    print("outOfRange")  

如果想要多次遍历数据的话,初始化外面包裹一层循环即可:

for i in range(2):
    sess.run(iterator.make_initializer(dataset))
    try:
        while True:
            print(sess.run([usersum,itemsum]))
    except tf.errors.OutOfRangeError:
        print("outOfRange")

完整代码

import numpy as np
import tensorflow as tf


data = np.load('data/test_data.npy').item()
print(type(data))


dataset = tf.data.Dataset.from_tensor_slices(data)
print(type(dataset))
dataset = dataset.shuffle(10000).batch(100)
print(type(dataset))

print(dataset.output_shapes)
print(dataset.output_types)

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)

print(type(iterator))


def getBatch():
    sample = iterator.get_next()
    print(sample)
    user = sample['user']
    item = sample['item']
    return user,item


users,items = getBatch()
usersum = tf.reduce_mean(users,axis=-1)
itemsum = tf.reduce_mean(items,axis=-1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(2):
        sess.run(iterator.make_initializer(dataset))
        try:
            while True:
                print(sess.run([usersum,itemsum]))
        except tf.errors.OutOfRangeError:
            print("outOfRange")
上一篇 下一篇

猜你喜欢

热点阅读