MXNet Source Code Analysis, Part 1
2017-12-30
迷途的Go
How to use it
```python
from mxnet import nd
from mxnet import autograd
from mxnet import gluon
from mxnet.gluon import nn


class Net(nn.Block):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        self.dense0 = nn.Dense(4, use_bias=False)
        self.dense1 = nn.Dense(2, use_bias=False)

    def forward(self, x):
        return self.dense1(self.dense0(x))


def train():
    net = Net()
    net.initialize()
    w = net.dense0.weight
    # Only the shape is printed here: calling w.data() at this point would
    # raise DeferredInitializationError, because with deferred initialization
    # the weight array is not allocated until the first forward pass.
    print('weight shape after initialize', w.shape)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
    data = nd.ones(shape=(1, 1, 28, 28))
    # The label must match the (1, 2) network output; a (1, 10) label would
    # fail inside L2Loss, which reshapes the label to the prediction's shape.
    label = nd.ones(shape=(1, 2))
    loss = gluon.loss.L2Loss()
    with autograd.record():
        res = net(data)
        w = net.dense0.weight
        print('net[0] name', net.dense0.name, 'weight shape', w.shape,
              '\nparams', w.data(), 'grad', w.grad())
        L = loss(res, label)
    L.backward()
    trainer.step(batch_size=1)
    print('net[0] name', net.dense0.name, 'weight shape', w.shape,
          '\nparams', w.data(), 'grad', w.grad())


if __name__ == '__main__':
    train()
```
Output of a run (recorded under Python 2, hence the tuple-style `print` output and the `4L` longs):
```
('weight shape after initialize', (4, 0))
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams',
[[ 0.04118239 0.05352169 -0.04762455 ..., 0.03089482 -0.00140258
0.01266012]
[-0.00697319 -0.00986735 -0.03128323 ..., 0.02195714 -0.04105704
0.01050965]
[ 0.02380178 -0.04182156 0.04908523 ..., -0.05005977 -0.0463761
0.0436078 ]
[-0.04813539 -0.03545294 -0.01216894 ..., 0.06526501 -0.00576673
-0.02751607]]
<NDArray 4x784 @cpu(0)>, 'grad',
[[ 0. 0. 0. ..., 0. 0. 0.]
[ 0. 0. 0. ..., 0. 0. 0.]
[ 0. 0. 0. ..., 0. 0. 0.]
[ 0. 0. 0. ..., 0. 0. 0.]]
<NDArray 4x784 @cpu(0)>)
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams',
[[ 0.02016377 0.03250307 -0.06864318 ..., 0.00987619 -0.0224212
-0.00835851]
[-0.05362909 -0.05652324 -0.07793912 ..., -0.02469876 -0.08771293
-0.03614624]
[ 0.0333778 -0.03224555 0.05866124 ..., -0.04048375 -0.03680009
0.05318382]
[-0.03410936 -0.02142691 0.00185709 ..., 0.07929104 0.00825929
-0.01349004]]
<NDArray 4x784 @cpu(0)>, 'grad',
[[ 0.02101862 0.02101862 0.02101862 ..., 0.02101862 0.02101862
0.02101862]
[ 0.04665589 0.04665589 0.04665589 ..., 0.04665589 0.04665589
0.04665589]
[-0.00957601 -0.00957601 -0.00957601 ..., -0.00957601 -0.00957601
-0.00957601]
[-0.01402603 -0.01402603 -0.01402603 ..., -0.01402603 -0.01402603
-0.01402603]]
<NDArray 4x784 @cpu(0)>)
```
The code above contains the typical structure of a neural network program:
- define the network (here, a small MLP)
- initialize the network
- train the network:
    - forward pass
    - compute the loss
    - backward pass to obtain the gradients
    - update the weight parameters
The code demonstrates two things:
- Right after the network is defined and initialized, the reported second dimension of the weight shape is 0. This is MXNet's deferred parameter initialization: until the network sees an input, the input dimension cannot be inferred. Compared with PyTorch, the advantage is that you never have to specify each layer's input size; the cost is that parameter shapes are unknown before the first forward pass. Here the first forward flattens the (1, 1, 28, 28) input to 784 features, so the weight shape becomes (4, 784). A small sketch of this behavior follows the list.
- The update rule is weight = weight - lr * grad. Taking the first parameter as an example: during the forward pass its value is 0.04118239 and its gradient is still 0; after one backward pass the gradient is 0.02101862, and since learning_rate is 1, the new value is 0.02016377 = 0.04118239 - 1 * 0.02101862.
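A minimal sketch of both points, using only the standard Gluon API; the standalone Dense layer and the 784-feature input are chosen to mirror the example above:
```python
from mxnet import nd
from mxnet.gluon import nn
from mxnet.gluon.parameter import DeferredInitializationError

dense = nn.Dense(4, use_bias=False)
dense.initialize()
print(dense.weight.shape)   # (4, 0): the input dimension is still unknown

try:
    dense.weight.data()     # the array has not been allocated yet
except DeferredInitializationError as e:
    print('deferred:', type(e).__name__)

x = nd.ones(shape=(1, 784))
dense(x)                    # the first forward pass triggers the real allocation
print(dense.weight.shape)   # (4, 784)

# The vanilla SGD step that trainer.step() applies to each parameter,
# checked against the numbers printed above (learning_rate = 1):
w, grad, lr = 0.04118239, 0.02101862, 1.0
print(w - lr * grad)        # 0.02016377...
```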
The example touches the key Gluon components:
- gluon.nn.Block, the base class behind Sequential, HybridBlock, and HybridSequential
- gluon.loss (L2Loss here)
- gluon.Trainer, a helper class for updating model parameters; a sketch of how it relates to the optimizer follows the list
- mxnet.optimizer
- mxnet.nd
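As a rough sketch of the relationship between gluon.Trainer and mxnet.optimizer (an illustration using the public optimizer API, not a walk through Trainer's internals): Trainer resolves the string 'sgd' to an mxnet.optimizer.SGD instance and applies its update rule to every parameter on each step. The values below reuse the first weight from the output above.
```python
import mxnet as mx
from mxnet import nd

# gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1}) resolves
# the string 'sgd' to an optimizer instance roughly like this:
opt = mx.optimizer.create('sgd', learning_rate=1)

# The optimizer exposes the raw update rule directly. For plain SGD,
# update(index, weight, grad, state) performs weight -= lr * grad in place.
weight = nd.array([[0.04118239]])
grad = nd.array([[0.02101862]])
state = opt.create_state(0, weight)   # SGD without momentum keeps no state
opt.update(0, weight, grad, state)
print(weight)                         # [[0.02016377]]
```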