MxNet源码解析(5) io
2018-09-14 本文已影响0人
Junr_0926
1. 前言
数据的读取很大程度上决定了代码运行的快慢。
先从python端的data io开始。在MXNet的python文件夹中,io.py
定义了io需要的函数,其中_init_to_moduel
会在我们import该包的时候运行,它的函数体如下:
def _init_io_module():
"""List and add all the data iterators to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
size = ctypes.c_uint()
check_call(_LIB.MXListDataIters(ctypes.byref(size), ctypes.byref(plist))) # 查看一共注册了多少个迭代器
module_obj = sys.modules[__name__]
for i in range(size.value):
hdl = ctypes.c_void_p(plist[i])
dataiter = _make_io_iterator(hdl)
setattr(module_obj, dataiter.__name__, dataiter) # 将迭代器注册到包名下
该函数会将注册的迭代器都在io的名字空间下注册。注册之前,会使用MXDataIter
进行包装,也就是说当我们在调用例如ImageRecordIter
的方法的时候,其实调用的是MXDataIter
的方法。
在初始化的时候,会调用Init
方法,在调用next
的时候,会调用对应的Next
方法,在调用reset
的时候会调用BeforeFirst
。
2. data.h
定义了迭代器的接口。
-
DataIter
:定义了数据迭代器的接口-
BeforeFirst
开始迭代之前的准备,reset
调用。 -
Next
返回下一个数据 -
Value
返回当前数据
-
class Row
3. io.h
MXnet中的迭代器
-
IIterator
: -
DataInst
:单个数据的结构体表示 -
DataBatch
:NDArray的DataBatch表示,由Iterator返回
4. augmenter
MXNet中定义了数据增强的迭代器,在源码中主要定义了两个增强的方法,一个是image_aug_default.cc
,一个image_det_aug_default
。
4.1 image_aug_default.cc
-
struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentParam>
:定义了augmentation需要的参数,Parameter
是dmlc中定义的一个基础结构体,用于定义参数。在文件parameter.h
中定义了该结构体,需要了解的是DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER, DMLC_DECLARE_FIELD
的作用-
DMLC_DECLARE_PARAMETER
:给定parameter的名称,用于定义这个parameter,该宏会创建一个静态函数:static ::dmlc::parameter::ParamManager *__MANAGER__();
,这个__MANAGER__
函数会返回一个类型为ParamManager
的变量。接着,宏会定义另一个函数:__DECLARE__
,函数体在后面的括号中。 -
DMLC_DECLARE_FIELD
:用于创建参数的一个域,也就是一个具体的参数。调用方法:this->DECLARE(manager, #FieldName, FieldName)
,这个方法的定义如下:
-
template<typename DType>
inline parameter::FieldEntry<DType>& DECLARE(
parameter::ParamManagerSingleton<PType> *manager,
const std::string &key, DType &ref) { // NOLINT(*)
parameter::FieldEntry<DType> *e =
new parameter::FieldEntry<DType>();
e->Init(key, this->head(), ref);
manager->manager.AddEntry(key, e);
return *e;
}
-
DMLC_REGISTER_PARAMETER
:用于注册这个参数,它会定义__MANAGER__
函数
(//TODO)
- 类
DefaultImageAugmenter
定义了数据增强的方法,它继承自ImageAugmenter
5. iter_image_recordio.cc
5.1 Init
初始化函数会首先初始化未定义的参数,然后使用参数初始化parser_
。接着会初始化iter_
,它的定义:dmlc::ThreadIter<std::vector<InstVector<DType>> > iter_;
6. inst_vector.h
TensorVector
:由一组Tensor组成的vector,支持不同shape
InstVector
:由一组(label, example)组成的列表
7. threaditer.h
定义了类:ThreadedIter
,在该类的内部定义了内部类:Producer
。
-
Init
: