1.mxnet代码学习1——Trainer

2017-12-08  本文已影响216人  吃个小烧饼

Trainer类是gluon下的一个类。顾名思义,就是主导“训练”的一个驱动方法。可以说不管写出多炫酷的网络结构,都需要从这个对象开始训练。可以从它的初始化方法__init__()看到一些基本信息:
def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
忽略最后一个参数kvstore,我们可以知道其需要的几个参数是:params, optimizer, optimizer_params

先说params,其注释部分有说:params : ParameterDict The set of parameters to optimize.,就是需要优化的参数啦。举个例子,一个cnn里,就是每个卷积核的那些参数。这个params如果获取到的是一个ParameterDict,我们就需要将其转化为一个listparams = list(params.values()),随后就把它写为一个内部 list 对象_params。所以说,我感觉在写工业的代码时,要充分将收集到的对象规范化,一般我自己在写算法代码时从来不考虑这些。

所以,要是你的参数很容易得到,层数也不多,完全可以手动添加到一个list里,在mxnet的官方中文教程里就有这样的示范:

params = [W1, b1, W2, b2, W3, b3, W4, b4]
for param in params:
    param.attach_grad()

就是收集到我们要更新的参数,然后记录它的梯度用以更新。既然说到这里了就可以看一下在这里如何更新:由于attach_grad对梯度创建一个“占位”,当执行完计算输出->输出和已有的做交叉熵计算loss之后,需要执行mxnet自带的backword来自动计算梯度。当然,具体mxnet(以后简写为mx)是如何计算的,它是拿cpp写的,有时间的话我也学习一下。当然,做完后向传播之后还要更新一下你的参数,比如SGD就是

def SGD(params, lr):
    for param in params:
        param[:] = param - lr * param.grad

说了这么多,好像越扯越远。确实,从想要说明参数表可以是 list rather than dict,花费了我大量不需要的时间······

回到正题。当我们有了_params之后,我们就需要来解析优化器optimizer了。在我们一般使用时,往往(或者绝对)输入的是一个字符串'sgd'这样的,然而实际上这里要配合前后2个参数将其转化为一个Optimizer类对象。这个Optimizer产生的方式、也就是直接写的话,是这样的:

sgd = mx.optimizer.Optimizer.create_optimizer('sgd',param_dict=param_dict, learning_rate=.1)

然后把这个对象返回个内部私有对象_optimizer。而在传递param_dict之前,首先要使用_params建立一个词典,方法也很简单:

param_dict = {i: param for i, param in enumerate(self._params)}

这里就要说到最重要的地方了。上边讲到的create_optimizer是这样声明的:

def create_optimizer(name, **kwargs):

name是一个str,**kwargs是字典形式的参数表,返回一个optimizer类对象,最骚的来了,这里的对象在创建的时候给一个叫opt_registry的词典以name
key ,以返回的类为 value 添加一项。这一部分在函数def register(klass):中。

而包含他们的文件optimizer.py 里有以register为装饰器执行register,所以在执行create之前必然先创建了以输入的name为名的对象。
这个装饰器的位置在类Optimizer之后,在所有优化方法之前。所以我们的操作就是:

所以当我们具体调用一个Trainer时,背后的操作如下:

trainer=gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':.1})

1.进入optimizer.py包,执行register = Optimizer.register对象化register()方法。

2.register函数对象为装饰器,遍历下边的具体优化方法类SGD Adam等,也就相当于register(SGD)等等方法。
执行此函数时创建Optimezer类,其中有一个字典opt_registry,我们把具体优化方法类的名字的小写作为opt_registry的一项的 key ,具体的优化方法类作为 value ,构建一个字典。

3.初始化函数对象createcreate = Optimizer.create_optimizer。噼里啪啦一顿操作后完成,模块读取完毕。

4.回到有trainer的模块,收集参数net.collect_params(),将其构建为字典,然后Trainer执行主方法,调用刚刚生产的create,以参数字典为一个参数,字符串sgdname,具体优化参数字典为**kwargs
具体的create过程就是,检查字符串name是否在刚刚构建的词典opt_registry内,在的话匹配,即初始化与其匹配的优化方法类:

Optimizer.opt_registry[name.lower()](**kwargs)

5.结束了,返回 >_<

上一篇下一篇

猜你喜欢

热点阅读