我爱编程

TensorFlow1.5 新年全新教程

2018-01-31  本文已影响249人  LucasJin

本文介绍 TensorFlow1.5 新年全新教程(系列)

TensorFlow1.5 新年全新教程(系列)

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

很久没有更博客了,眨眼都已经2018年了,遥想去年跨年就好像发生在前天一样,预祝大家2019年猪年大吉。

闲话不多说。在家呆久了不学点东西感觉心虚,科技发展这么快,不脚踏实地开疆拓土怎么行呢?新年就要有新气象嘛,作为一位人工智能行业从业者,希望以一个过来的人的身份,带领更多的人在这条道路上披荆斩棘,开拓新的领域。工欲善其事必先利其器,TensorFlow1.5都已经发布了,我们还有什么理由不去学习一下最新的tf.data.Dataset API? 还有什么理由不期待一下TensorFlow Lite的终极版本以及专属于移动端的模型存储框架FlatBuf…感觉科技又前进了一个世纪,不过没有关系。凡事都得从当下做起。自从1.5版本发布 之后,tensorflow里面的很多API都将冻住了,并且会越来越规范化,为的正式迎接2018年深度学习应用落地的爆发之年。

闲话就说到这里了。我们首先从tensorflow的最新dataset API说起。

开始之前给大家安利一个工具:alfred, 专门为深度学习打造的工具,欢迎大家star, fork,enhance。我们接下来用它来随时爬几张猪啊狗啊的图片。

tf.data.Dataset

这个以前是在contrib下面的一个接口,现在放到了data下面,可以说是非常正统的tensorflow数据导入接口了。以前都是用tfrecords,现在不管是从单张图片,从文件夹路径,还是从numpy array类型的数据,都非常方便了。

假设我们有一个图片分类的简单任务。我们的目录是这样的:

-data
    |-dog
    |-pig
    |-...

这个猪啊狗啊的图片alfred可以帮你爬取:

sudo pip3 install alfred-py
alfred scrap image -q 'dog'
alfred scrap image -q 'pig'

每个类别装了许多同一类的图片。那直接读取到python的list,然后转成tensor,通过tf.data.Dataset就可以读入到tensorflow里面。

import tensorflow as tf
import os


NUMC_CLASSES = 2


def load_image():
    train_dir = 'data'
    all_classes = []
    all_images = []
    all_labels = []
    for i in os.listdir(train_dir):
        current_dir = os.path.join(train_dir, i)
        if os.path.isdir(current_dir):
            all_classes.append(i)
            for img in os.listdir(current_dir):
                if img.endswith('png') or img.endswith('jpg'):
                    all_images.append(os.path.join(current_dir, img))
                    all_labels.append(all_classes.index(i))
    return all_images, all_labels, all_classes


def train():
    all_images, all_labels, all_classes = load_image()
    print(all_classes)
    # convert all images list to tensor, using Dataset API to load
    train_data = tf.data.Dataset.from_tensor_slices((tf.constant(all_images), tf.constant(all_labels)))
    iterator = tf.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)

    next_elem = iterator.get_next()
    train_init_op = iterator.make_initializer(train_data)

    with tf.Session() as sess:
        sess.run(train_init_op)
        while True:
            try:
                print(sess.run(next_elem))
            except tf.errors.OutOfRangeError:
                print('data iterator finish.')
                break

if __name__ == '__main__':
    train()

我们可以看到输出结果是:

['dog', 'pig']
(b'dog_00.jpg', 0)
(b'dog_01.jpg', 0)
(b'pig_00.jpg', 1)
(b'pig_01.jpg', 1)
(b'pig_010.jpg', 1)
(b'pig_02.jpg', 1)
(b'pig_03.jpg', 1)
(b'pig_04.jpg', 1)
(b'pig_05.jpg', 1)
(b'pig_06.jpg', 1)
(b'pig_07.jpg', 1)
(b'pig_08.jpg', 1)
(b'pig_09.jpg', 1)
data iterator finish.

图片和标签都已经获得。用最新的Dataset API中的 from_tensor_slices可以非常方便的从list中将数据导入。

很多时候我们都需要对图片进行预处理,比如我们需要做一个检测数据集,我们要读入label和bbox,这个时候label需要one-hot,我们就需要对这个东西进行预处理,这个时候map就有用了。

tf.data.Dataset.map

这还没有完,我们的目的是操作每一张图片,做一些变换。或者对label进行一些处理,比如one-hot。在最新的dataset API中也有map函数进行操作。可以在这个map方法里,指定所有应有的操作。

def input_map_fn(img_path, label):
    # do some process to label
    one_hot = tf.one_hot(label, NUMC_CLASSES)
    img_f = tf.read_file(img_path)
    img_decodes = tf.image.decode_image(img_f, channels=3)
    return img_decodes, one_hot

然后将train_data加上即可。

    train_data = train_data.map(input_map_fn)

最终我们可以看到熟悉的,图片值 + one_hot label的训练数据。如果是对于像多标签分类,目标检测这样的任务label,也是做同样的处理。只要能保证前期的输入能在后期的网络中拿到就行了。

好了,现在tensorflow全新的数据导入API应该已经融会贯通了。下一篇大家等待更新,博主这还得去乡下拜个年。

上一篇下一篇

猜你喜欢

热点阅读