MxNet源码解析(2) symbol
1. 前言
我们在训练之前,先建立好一个图,然后我们可以在这个图上做我们想做的优化,这种形式称为Symbolic Programs。相对应的是Imperative Programs,也就是每一句代码都对应着程序的执行,在这种情况下,我们可以写类似于下面的代码:
a = 2
b= a + 1
d = np.zeros(10)
for i in range(d):
d += np.zeros(10)
这在symbolic的方式下是做不到的,因为在for循环开始时,程序并不知道d
的值,也就无法判断循环的次数。
因此我们可以说,symbolic更高效,imperative更灵活。
MxNet是一个异步式的训练框架,它支持上面的两种形式。我们可以使用NDArray
来进行imperative形式的程序编写,也可以使用symbol
来建立图。
2. op
先来了解operator
,不了解operator
可能就很难理解源码中占据了很大一部分的operator的定义。就是通过这些operator来将symbol连接成为了一个图。
-
OpManager
:单例结构体,通过OpManager::Global()
总会返回同一个结构体。Op的构造函数会将OpManager
的op_counter
加一,并且将自己的index_
注册为当前的op_counter
。 -
add_alias
:将别名注册到`dmlc::Registry<Op>中 -
Get
:根据name
返回Op
GetAttrMap
2.1 op
-
name
:名字 -
description
:该op的描述 -
num_inputs
:输入的个数 -
num_outputs
:输出的个数 -
get_num_outputs, get_num_inputs
:函数,返回输出,输入的个数 -
attr_parser
:函数,用于方便返回该op的参数 -
Op& Op::describe(const std::string& descr)
:方法用于将输入注册到description变量中,并返回这个op,方便接着调用其他方法。
2.2 几个宏
-
#define NNVM_REGISTER_VAR_DEF(OpName)
:定义OpName -
#define NNVM_REGISTER_VAR_DEF(TagName)
:定义TagName
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName, __COUNTER__) = \
::dmlc::Register<::nnvm::op>::Get()->__REGISTER_OR_GET(#OpName)
注册op,并返回该op
3. Node
Node是组成symbol的基本组件。
结构体NodeEntry
包含了:
-
node
:指向node的指针 -
index
:输出的索引值 -
version
:输入的version
结构体NodeAttrs
包含了:
-
op
: 指向operator的指针 -
name
: node的名字 -
dict
:attributes的字典
类Node
包含:
-
attrs
:结构体NodeAttrs
成员,存储了op, name, attributes
等信息。 -
inputs
:输入,是一个元素为NodeEntry
的向量 -
control_deps
:保存了应该在该node执行之前执行的node。 -
op()
:返回该Node的operator,就是返回attrs
中保存的op
-
Create()
:类方法,静态方法,用于新建一个Node,返回指向它的指针 -
num_outputs
:如果是变量,输出为1,否则返回op
的输出
几个函数
定义在文件op_attr_types.h
中
-
FListinputNames
:返回输入的名字,默认return {'data'}
-
FNumVisibleOutputs
:用于隐藏一些输出 -
FListOutputNames
:返回输出的名字 -
FMutateInputs
:返回该node会改变的node的索引值 -
FInferNodeEntryAttr
:推理出AttrType
-
FInferShape
:推理shape,也就是上面的AttrType
为Tshape
-
FInferType
:推理类型 -
TIsBackward
是否是反向传播 FInplaceOption
-
FGradient
:返回node的梯度节点 -
FSetInputVarAttrOnCompose
:为输入设置attribute -
FCorrectLayout
:推理layout -
FInputGraph
:返回输入,解释为图而不是数据
这些函数是在定义具体的op时,可以选择注册对应的函数。
4. Symbol
Symbol是为了使用Node建立Graph。Symbol是我们能够直接接触的类,它定义了一系列方法用于更方便地构建图。在symbol的成员outputs
中,定义了一组由NodeEntry
组成的向量。
-
outputs
:该symbol包含的输出,是一个元素是NodeEntry
的向量 -
Copy
:返回一个深拷贝,方式是通过遍历Node,每次访问到的Node保存起来,再建立起node之间的连接,最后将head加入到outputs中。 -
Symbol operator[] (size_t index) const
:返回第个输出。 -
ListInputs
:返回输入 -
ListInputNames
:返回输入的名字 -
Compose
:组合symbol -
operator ()
:调用compose,来组合symbol -
AddControlDeps
:加入控制,用于有向图的构建 -
GetInternals
:返回一个symbol,它的输出是原来symbol的输出加上所有中间输出和输入 -
GetChildren
: -
SetAttrs
:设置attribution -
GetAttrs
: -
CreateFunctor
:给定op和attrs,返回一个symbol
我认为symbol
中比较重要的函数是compose,在调用的时候我们是通过调用symbol
的操作符()
函数,也就是operator ()
,该函数将参数传递给Compose
。
5. Graph
类Graph
就是计算的时候使用的图
-
outputs
:和symbol
的outputs
一样,类型为std::vector<NodeEntry>
-
attrs
:定义了图的一些属性 -
PostOrderDFSVisit
:后序遍历图,给定参数head,进行拓扑排序。算法,貌似,就是拓扑排序算法。 -
DFSVisit
:调用PostOrderDFSVisit
,对图的head进行拓扑排序。参数为:const std::vector<NodeEntry>& heads, FVisit fvisit
,其中head
是反向传播时的头节点,fvisit
是访问时调用的函数,该方法将fvisit(*n)
作为访问节点时的函数,[](GNode n)->Node*{return->get();}
作为hash函数,这个函数看签名返回的是一个指向节点的指针。图的节点入度计算如下:
[](GNode n)->uint32_t {
if (!(*n)) return 0;
return (*n)->input.size() + (*n)->control_deps.size();
}
节点输入计算如下:
[](GNode n, uint32_t index)->GNode {
if (index < (*n)->input.size()) {
return &(*n)->input.at(index).node;
} else {
return &(*n)->contorl_deps.at(index - (*n)->inputs.size());
}
6. IndexedGraph
IndexedGraph
由Graph
返回,
-
nodes_
:成员变量,一个指向Node
结构体的向量,Node
定义如下:
struct Node {
const nnvm::Node* source;
array_view<NodeEntry> inputs;
array_view<uint32_t> control_deps;
std::weak_ptr<nnvm::Node> weak_ref;
};
其中NodeEntry
如下:
struct NoodeEntry {
uint32_t node_id;
uint32_t index;
uint32_t version;
};
成员变量:
-
input_nodes_
:输入node的索引 mutable_input_nodes_
-
outputs
:输出节点 -
node2index
:node到索引的映射 -
entry_rptr_
: -
input_entries_
: -
control_deps_
:
方法: DFSVisit
PostOrderDFSVisti
7. pass
7.1 gradient.cc
-
Gradient
:gradient
会根据属于的graph
,返回一个带反向传播图的新图。它主要由executor
建立图的时候调用,调用方式如下:
nnvm::Graph g_grad = nnvm::pass::Gradient(g,
symbol.outputs, xs, head_grad_entry_, ArggregateGradient,
need_mirror, nullptr, zero_ops, "_copy");
调用该方法会调用文件pass_function.h
下的Gradient
函数。该函数将传入的参数保存在graph
下的attrs
中。再通过applypass
调用Gradient方法。也就是在该文件下定义的方法,签名:Graph Gradient(Graph src)
。
- 根据DFSVisit进行拓扑排序,将序列存储到
topo_order
中 - 将输出的梯度保存在
output_grads
- 根据
mirror_fun
在适当的地方插入新的节点,来实现内存的复用
-
DefaultAggregateGradient
: