mxnet源码分析1

2017-12-30  本文已影响0人  迷途的Go

如何使用

from mxnet import nd
from mxnet.gluon import nn
from mxnet import gluon
from mxnet import autograd

class Net(nn.Block):
    """Toy two-layer network: a 4-unit Dense layer feeding a 2-unit Dense layer.

    Neither layer has a bias term. Because no ``in_units`` is given, the
    input width of the first layer is only known after the first forward
    pass (its weight shape reads ``(4, 0)`` until then).
    """

    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        # Creation order matters: it determines the auto-assigned layer
        # names ('dense0' then 'dense1') that train() prints.
        self.dense0 = nn.Dense(4, use_bias=False)
        self.dense1 = nn.Dense(2, use_bias=False)

    def forward(self, x):
        """Apply dense0, then dense1, and return dense1's output."""
        hidden = self.dense0(x)
        out = self.dense1(hidden)
        return out
def train():
    """Run a single SGD step on a toy input and print dense0's weights.

    Builds ``Net``, feeds one all-ones sample of shape (1, 1, 28, 28) with an
    all-ones target, and prints the first layer's weight and gradient
    during the recorded forward pass and again after ``trainer.step``.
    """
    net = Net()
    net.initialize()
    w = net.dense0.weight
    # dense0 was created without in_units, so after initialize() its shape
    # is still (4, 0): initialization is deferred until the first forward
    # pass infers the input width. Calling w.data() here would raise
    # DeferredInitializationError, so only the shape is printed.
    print('weight shape after initialize', w.shape)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
    data = nd.ones(shape=(1, 1, 28, 28))
    # The target must have the same shape as the network output; dense1
    # has 2 units. (A (1, 10) label is either silently truncated by
    # L2Loss's reshape or rejected, depending on the MXNet version.)
    label = nd.ones(shape=(1, 2))
    loss = gluon.loss.L2Loss()
    with autograd.record():
        res = net(data)
        # The forward pass has now fixed dense0's weight shape; its gradient
        # has not been computed yet (backward runs below).
        w = net.dense0.weight
        print('net[0] name', net.dense0.name, 'weight shape', w.shape,
              '\nparams', w.data(), 'grad', w.grad())
        L = loss(res, label)
    L.backward()
    trainer.step(batch_size=1)
    print('net[0] name', net.dense0.name, 'weight shape', w.shape,
          '\nparams', w.data(), 'grad', w.grad())


if __name__ == '__main__':
    train()

执行输出结果（在 Python 2 下运行，因此形状中出现 4L、784L 这样的长整型写法，元组形式的 print 输出也源于此）:

('weight shape after initialize', (4, 0))
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams', 
[[ 0.04118239  0.05352169 -0.04762455 ...,  0.03089482 -0.00140258
   0.01266012]
 [-0.00697319 -0.00986735 -0.03128323 ...,  0.02195714 -0.04105704
   0.01050965]
 [ 0.02380178 -0.04182156  0.04908523 ..., -0.05005977 -0.0463761
   0.0436078 ]
 [-0.04813539 -0.03545294 -0.01216894 ...,  0.06526501 -0.00576673
  -0.02751607]]
<NDArray 4x784 @cpu(0)>, 'grad', 
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
<NDArray 4x784 @cpu(0)>)
('net[0] name', 'dense0', 'weight shape', (4L, 784L), '\nparams', 
[[ 0.02016377  0.03250307 -0.06864318 ...,  0.00987619 -0.0224212
  -0.00835851]
 [-0.05362909 -0.05652324 -0.07793912 ..., -0.02469876 -0.08771293
  -0.03614624]
 [ 0.0333778  -0.03224555  0.05866124 ..., -0.04048375 -0.03680009
   0.05318382]
 [-0.03410936 -0.02142691  0.00185709 ...,  0.07929104  0.00825929
  -0.01349004]]
<NDArray 4x784 @cpu(0)>, 'grad', 
[[ 0.02101862  0.02101862  0.02101862 ...,  0.02101862  0.02101862
   0.02101862]
 [ 0.04665589  0.04665589  0.04665589 ...,  0.04665589  0.04665589
   0.04665589]
 [-0.00957601 -0.00957601 -0.00957601 ..., -0.00957601 -0.00957601
  -0.00957601]
 [-0.01402603 -0.01402603 -0.01402603 ..., -0.01402603 -0.01402603
  -0.01402603]]
<NDArray 4x784 @cpu(0)>)

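以上代码包含了一个神经网络训练的典型结构：定义模型（Net）、初始化参数（initialize）、创建优化器（Trainer）、定义损失函数（L2Loss）、在 autograd.record() 中做前向计算、调用 backward() 反向传播，最后用 trainer.step() 更新参数。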

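上面的代码证明了两样东西：一是未指定输入维度时参数采用延迟初始化，initialize() 之后 dense0 的权重形状是 (4, 0)，第一次前向计算后才确定为 (4, 784)；二是 backward() 之后梯度才写入 grad（前向时打印的 grad 全为 0），trainer.step() 按 SGD 规则更新权重，即 新权重 = 旧权重 − 学习率 × grad（例如 0.04118239 − 1 × 0.02101862 = 0.02016377）。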

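以上涉及了gluon的关键组件：nn.Block（自定义模型的基类）、nn.Dense（全连接层）、Parameter（如 dense0.weight，提供 data() 和 grad()）、gluon.Trainer（优化器）、gluon.loss.L2Loss（损失函数），以及配合使用的 autograd 模块。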

上一篇 下一篇

猜你喜欢

热点阅读