MXNETTensorFlow人工智能之路 - 不停顿的蜗牛

[MXnet] 如何将数据集加载到MXnet中

2016-05-27  本文已影响1976人  ToeKnee

MXnet的学习笔记,这次主要是使用MXnet提供的example模型进行训练时如何加载数据集的介绍。步骤基本上按照MXNet Python Data Loading API
有关MXnet在OSX下的编译安装,可以看这里Mac下编译安装MXNet
有关MXnet提供的example的综述介绍<-在这里。

Sample iterator for data loading


在浏览完MXnet提供的example后想要在自己的机器上跑一下简单的数据集看看结果。因为现在只是装在自己的MBA上,没有装CUDA和OpenMP,也没有使用GPU训练,因此只能跑一跑简单的数据集。MXnet的Image Classification Example中的样例都比较完整,使用步骤也很详细,训练最基本的MNIST数据集基本上不需要多余的工作量,只要能联网下载MNIST数据集(或者自己有数据集的话移动到对应文件夹下)就可以直接训练,效果也挺不错:

→ python train_mnist.py 
2016-05-23 08:51:41,616 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)
[08:51:45] src/io/iter_mnist.cc:91: MNISTIter: load 60000 images, shuffle=1, shape=(128,784)
[08:51:46] src/io/iter_mnist.cc:91: MNISTIter: load 10000 images, shuffle=1, shape=(128,784)
2016-05-23 08:51:46,460 Node[0] Start training with [cpu(0)]
...
2016-05-23 08:52:02,548 Node[0] Epoch[9] Batch [450]    Speed: 41054.59 samples/sec Train-top_k_accuracy_20=1.000000
2016-05-23 08:52:02,605 Node[0] Epoch[9] Resetting Data Iterator
2016-05-23 08:52:02,605 Node[0] Epoch[9] Time cost=1.470
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-accuracy=0.977464
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.999299
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_10=1.000000
2016-05-23 08:52:02,750 Node[0] Epoch[9] Validation-top_k_accuracy_20=1.000000

默认的参数为:batch-size=128,初始学习率为0.1(固定学习率,lr_factor_epoch=1),使用最基本的多层感知机MLP进行训练。每个epoch耗时大约1.5秒左右,在10次迭代后测试集的accuracy达到0.977464。
其实在测试的时候看到Validation-accuracy时就有在想指的是cross-validation的accuracy还是test的accuracy,因此这时候就可以先去看看MXnet中到底是怎么读取数据、怎么使用KVstore的。

根据官方文档的介绍,MXnet使用iterator将参数传递给训练模型。这里的iterator会做一些数据预处理,并且生成指定大小的batch输入训练模型。
由于MNIST的数据比较简单,example里面提供了载入MNIST数据集的iterator实现,如下:

def get_iterator(data_shape):
    def get_iterator_impl(args, kv):
        data_dir = args.data_dir
        # 若指定位置没有MNIST数据集则会调用_download()函数联网下载
        if '://' not in args.data_dir:
            _download(args.data_dir)
        # data_shape变量为输入数据的格式。对于MNIST:
        # 若使用MLP进行训练,输入数据为有784个元素的一维向量,data_shape = (784, )
        # 若使用LeNet进行训练,输入数据为一个28*28的矩阵,data_shape = (1, 28, 28)
        # 因此若len(data_shape)不等于3时,设置flat变量为True,即对MNIST每一个输入数据一维扁平化
        flat = False if len(data_shape) == 3 else True

        # 训练集的参数指定
        train           = mx.io.MNISTIter(
            image       = data_dir + "train-images-idx3-ubyte",
            label       = data_dir + "train-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            ## A commonly mistake is forgetting shuffle the image list during packing.
            ## This will lead fail of training.
            ## eg. accuracy keeps 0.001 for several rounds.
            shuffle     = True,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        # 测试集的参数指定
        val = mx.io.MNISTIter(
            image       = data_dir + "t10k-images-idx3-ubyte",
            label       = data_dir + "t10k-labels-idx1-ubyte",
            input_shape = data_shape,
            batch_size  = args.batch_size,
            flat        = flat,
            num_parts   = kv.num_workers,
            part_index  = kv.rank)

        return (train, val)
    return get_iterator_impl

train_mnist.py 的main函数里,会调用get_iterator()函数得到输入的iterator,传递给train_model.fit()函数执行真正的训练过程。
在之前的example介绍里有说到,Image Classification(包括后面基于CNN的很多其它网络)的不同网络结构运用在不同的数据集上,最后都是回到调用train_model.fit()函数进行训练。因此输入数据的获取和iterator的定义都在对应的 train_{mnist, cifar10, imagenet}.py 中,最简单的定义就如上面的代码所示。

Build your own iterator


MNIST输入数据的格式类型分为recordio,MNISTcsv。MNIST数据集的参数指定较为简单,上面的例子基本都覆盖到了。有关csv和MNIST数据集的更多参数指定信息<--点击链接。
对于图片数据集(recordio格式的数据),在创建iterator时,一般需要指定的参数有五类,包括:

  1. 数据集参数 (Dataset Param),提供了数据集的基本信息,如数据文件地址、数据形状(即前例中的input_shape)等等。
  2. 批参数 (Batch Param) 提供了形成batch的信息,比如batch size
  3. Augmentation Param 可以设定对数据集预处理的参数,比如mean_image(将图像中的每个像素减去图片像素均值),rand_crop(随机对图像进行部分切割),rand_mirror(随机对图像进行水平对称变换)等等。
  4. 后台参数 (Backend Param) 控制后台线程来隐藏读取数据的开销的相关参数,如preprocess_threads设定后台预读取线程数量,prefetch_buffer设定预读取buffer的大小。
  5. 辅助参数 (Auxiliary Param) 提供用于调试的参数设定,如verbose设定是否要输出parser信息。

具体的参数定义可以看官方文档:I/O API

Use your own data


要使用自己的数据集(或者ImageNet数据集),由于MXnet没有提供类似MNIST和cifar的自动下载和加载脚本将原始数据转换为ImageRecord数据,因此需要自己进行数据格式转换。
不过将数据转换为ImageRecord格式也很简单:

integer_image_index \t label_index \t path_to_image

make_list接受的参数包括

这边会遇到一点问题,如果调用im2rec.py的时候提示

No module named cv

的话,网上查询到的原因是没安装openCV(不过其实之前装了……)
那只要把代码中

import cv, cv2

中的cv去掉即可,后续好像只使用到了cv2库中的内容,不需要cv。

然后就可以在MXnet中使用自己的数据集了。

上一篇下一篇

猜你喜欢

热点阅读