mxnetMXNet

MxNet源码解析(2) symbol

2018-09-14  本文已影响0人  Junr_0926

1. 前言

我们在训练之前,先建立好一个图,然后我们可以在这个图上做我们想做的优化,这种形式称为Symbolic Programs。相对应的是Imperative Programs,也就是每一句代码都对应着程序的执行,在这种情况下,我们可以写类似于下面的代码:

a = 2
b= a + 1
d = np.zeros(10)
for i in range(d):
    d += np.zeros(10)

这在symbolic的方式下是做不到的,因为在for循环开始时,程序并不知道d的值,也就无法判断循环的次数。
因此我们可以说,symbolic更高效,imperative更灵活。

MxNet是一个异步式的训练框架,它支持上面的两种形式。我们可以使用NDArray来进行imperative形式的程序编写,也可以使用symbol来建立图。

2. op

先来了解operator,不了解operator可能就很难理解源码中占据了很大一部分的operator的定义。就是通过这些operator来将symbol连接成为了一个图。

2.1 op

2.2 几个宏

#define NNVM_REGISTER_OP(OpName) \
  DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName, __COUNTER__) = \
    ::dmlc::Register<::nnvm::op>::Get()->__REGISTER_OR_GET(#OpName)

注册op,并返回该op

3. Node

Node是组成symbol的基本组件。
结构体NodeEntry包含了:

结构体NodeAttrs包含了:

Node包含:

几个函数

定义在文件op_attr_types.h

这些函数是在定义具体的op时,可以选择注册对应的函数。

4. Symbol

Symbol是为了使用Node建立Graph。Symbol是我们能够直接接触的类,它定义了一系列方法用于更方便地构建图。在symbol的成员outputs中,定义了一组由NodeEntry组成的向量。

5. Graph

Graph就是计算的时候使用的图

[](GNode n)->uint32_t {
  if (!(*n)) return 0;
  return (*n)->input.size() + (*n)->control_deps.size();
}

节点输入计算如下:

[](GNode n, uint32_t index)->GNode {
  if (index < (*n)->input.size()) {
    return &(*n)->input.at(index).node;
  } else {
  return &(*n)->contorl_deps.at(index - (*n)->inputs.size());
}

6. IndexedGraph

IndexedGraphGraph返回,

struct Node {
  const nnvm::Node* source;
  array_view<NodeEntry> inputs;
  array_view<uint32_t> control_deps;
  std::weak_ptr<nnvm::Node> weak_ref;
};

其中NodeEntry如下:

struct NoodeEntry {
  uint32_t node_id;
  uint32_t index;
  uint32_t version;
};

成员变量:

7. pass

7.1 gradient.cc

nnvm::Graph g_grad = nnvm::pass::Gradient(g, 
            symbol.outputs, xs, head_grad_entry_, ArggregateGradient,
            need_mirror, nullptr, zero_ops, "_copy");

调用该方法会调用文件pass_function.h下的Gradient函数。该函数将传入的参数保存在graph下的attrs中。再通过applypass调用Gradient方法。也就是在该文件下定义的方法,签名:Graph Gradient(Graph src)

  1. 根据DFSVisit进行拓扑排序,将序列存储到topo_order
  2. 将输出的梯度保存在output_grads
  3. 根据mirror_fun在适当的地方插入新的节点,来实现内存的复用

7.2 plan_memory.cc

7.3 place_device.cc

7.4 correct_layout.cc

上一篇下一篇

猜你喜欢

热点阅读