Starting the Deep Learning Journey: Training the MNIST Dataset with mxnet

2019-12-31  LonnieQ

Source code

https://github.com/LoniQin/AwsomeNeuralNetworks/blob/master/trainer_v1.py

1. Import dependencies

from mxnet.gluon import data as gdata
import time
from mxnet import autograd, nd

2. Define constants

As shown below, num_inputs is the number of input features (each 28 × 28 image is flattened to 784 pixels), num_outputs is the number of output classes, batch_size is the number of samples in each training batch, num_workers is the number of worker processes used to load data, num_epochs is the number of training epochs, and learning_rate is the learning rate.

num_inputs = 784
num_outputs = 10
batch_size = 256
num_workers = 4
num_epochs = 100
learning_rate = 0.1

3. Load the data

mnist_train and mnist_test are the training and test sets. The constructor first checks whether the dataset already exists locally; if so, it is loaded from disk, otherwise it is downloaded and cached locally. train_iter and test_iter can then be iterated over with a for loop during training. Note that gdata.vision.FashionMNIST actually loads Fashion-MNIST, a drop-in replacement for MNIST with the same 28 × 28 grayscale format and ten classes.

mnist_train = gdata.vision.FashionMNIST(train=True)

mnist_test = gdata.vision.FashionMNIST(train=False)

transformer = gdata.vision.transforms.ToTensor()

train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size, shuffle=True, num_workers=num_workers)

test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size, shuffle=False, num_workers=num_workers)
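
A quick sanity check (my addition, not in the original source): pull one batch from train_iter and inspect its shape. After ToTensor, each image is a 1 × 28 × 28 float32 tensor with values in [0, 1]:

for X, y in train_iter:
    print(X.shape, X.dtype)  # (256, 1, 28, 28), float32
    print(y.shape)           # (256,)
    break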

4. Define the classification function

In image classification with n classes, the network needs n outputs. We want to turn these n outputs into per-class probabilities and pick the most likely class, and the softmax function handles this well. Suppose a dataset has n classes and the outputs computed by a linear model (or some other means) are o_1, ..., o_n; softmax is defined as:
softmax(o_i) = \frac{\exp(o_i)}{\sum_{j=1}^n \exp(o_j)}
Clearly, the number of outputs equals the number of label classes. Softmax yields a probability for every class, and the probabilities sum to 1.

The Python code is:

def softmax(X):
    x_exp = X.exp()                               # exponentiate each score
    partition = x_exp.sum(axis=1, keepdims=True)  # row-wise normalizing constant
    return x_exp / partition                      # broadcasts across each row
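
A small check (my addition): each row of the softmax output is a valid probability distribution that sums to 1.

X = nd.random.normal(shape=(2, 5))
probs = softmax(X)
print(probs.sum(axis=1))  # both entries are ~1.0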

5. Define the loss function

The loss function measures how well the neural network is performing. For classification problems, cross-entropy is a common loss function; it measures the difference between two probability distributions. Its formula is:
H(y^{(i)}, \hat{y}^{(i)}) = -\sum_{j=1}^q y_j^{(i)} \log \hat{y}_j^{(i)}
For a training set of n examples, the cross-entropy loss function is defined as:
l(\theta) = \frac{1}{n}\sum_{i=1}^n H(y^{(i)}, \hat{y}^{(i)})

def cross_entropy(y_hat, y):
    # Pick the predicted probability of the true class, then take its negative log
    return -nd.pick(y_hat, y).log()
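
A toy example (mine, not from the original post) to show what nd.pick does: for each row of y_hat, it selects the predicted probability of the true class indexed by y, and the loss is its negative log.

y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = nd.array([2, 0])
print(cross_entropy(y_hat, y))  # [-log(0.6), -log(0.3)] ≈ [0.51, 1.20]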

6. Define the parameter update function

During training, we first randomly initialize a set of weights W and biases b. We run the data through the network to compute predictions, compare them against the labels with the loss function, and obtain a loss value. Backpropagation then computes the gradients of W and b, and some update rule adjusts W and b so that the next pass produces predictions closer to the answers. Stochastic gradient descent plays exactly this role: it updates the parameters according to the learning rate we set, which is typically chosen between 0 and 1.
The implementation is as follows:

def sgd(params, lr, batch_size):
    for i in range(len(params)):
        # In-place update: param <- param - lr * grad / batch_size
        nd.elemwise_sub(params[i], lr * params[i].grad / batch_size, out=params[i])

We need to choose a suitable learning rate. If it is too small, it takes a long time to approach the optimum; if it is too large, we can easily step past the optimum, as the small illustration below shows.
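
A minimal illustration (my addition): plain gradient descent on f(x) = x², whose gradient is 2x. A tiny rate converges slowly, a moderate one converges quickly, and too large a rate makes the iterates diverge.

def descend(lr, steps=10, x=1.0):
    # One-dimensional gradient descent: x <- x - lr * f'(x), with f'(x) = 2x
    for _ in range(steps):
        x = x - lr * 2 * x
    return x

print(descend(0.01))  # ~0.82, still far from the optimum at 0: too slow
print(descend(0.4))   # ~1e-7, essentially converged
print(descend(1.1))   # ~6.19, each step overshoots and |x| grows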

7. Build the neural network

For this classification problem we use a linear model (softmax regression); its expression is:
y = XW + b
The code is:

def net(X, w, b):
    # Flatten each image into a 784-dimensional row vector, apply the linear layer
    y = nd.dot(X.reshape((-1, w.shape[0])), w) + b
    return softmax(y)

This is only a single-layer network. In real training we may build networks with several, or even hundreds, of layers; building deeper networks is one way to improve accuracy. A sketch of what that could look like follows below.
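
A hedged sketch (my addition, using gluon's high-level API rather than the manual code above) of a two-layer network with a hidden ReLU layer:

from mxnet import init
from mxnet.gluon import nn

deeper_net = nn.Sequential()
deeper_net.add(nn.Dense(256, activation='relu'),  # hidden layer
               nn.Dense(num_outputs))             # output layer, 10 classes
deeper_net.initialize(init.Normal(sigma=0.01))

gluon's Dense layer flattens the 1 × 28 × 28 input automatically, and in this setting the softmax would normally be folded into gluon's SoftmaxCrossEntropyLoss for numerical stability.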

8. Train the model

In this example, we first initialize the weights w and bias b and attach gradient buffers to them. With batch_size = 256 and 60,000 training examples (the size of the Fashion-MNIST training set), each epoch runs about 235 batches, and we train for num_epochs = 100 epochs. In each batch, we run the network to compute the prediction \hat{y}, compute the loss l with the loss function, and backpropagate l so that w and b receive gradients; finally, we call the update function to update w and b.

w = nd.random.normal(scale=1.0, shape=(num_inputs, num_outputs))
b = nd.zeros(num_outputs)
w.attach_grad()
b.attach_grad()
loss = cross_entropy
# Train models
start = time.time()
for epoch in range(1, num_epochs + 1):
    acc_sum, n = 0.0, 0
    total = float(len(mnist_train))
    for X, y in train_iter:
        with autograd.record():
            y_hat = net(X, w, b)
            l = loss(y_hat, y).sum()
        l.backward()
        sgd([w, b], learning_rate, batch_size)
        acc_sum += (y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar()
    print("Epoch:%d Elapsed time:%.2f accuracy:%.2f%%" % (epoch, time.time() - start, (acc_sum / total) * 100))

9. Results

After 100 epochs of training, the training accuracy rises from 44.04% to 83.85%.

Epoch:1 Elapsed time:1.37 accuracy:44.04%
Epoch:2 Elapsed time:2.69 accuracy:62.63%
Epoch:3 Elapsed time:4.00 accuracy:67.57%
Epoch:4 Elapsed time:5.32 accuracy:70.34%
Epoch:5 Elapsed time:6.63 accuracy:72.17%
Epoch:6 Elapsed time:8.09 accuracy:73.39%
Epoch:7 Elapsed time:9.36 accuracy:74.33%
Epoch:8 Elapsed time:10.63 accuracy:75.06%
Epoch:9 Elapsed time:11.93 accuracy:75.83%
Epoch:10 Elapsed time:13.20 accuracy:76.27%
Epoch:11 Elapsed time:14.59 accuracy:76.70%
Epoch:12 Elapsed time:15.89 accuracy:77.01%
Epoch:13 Elapsed time:17.22 accuracy:77.41%
Epoch:14 Elapsed time:18.59 accuracy:77.81%
Epoch:15 Elapsed time:19.93 accuracy:78.02%
Epoch:16 Elapsed time:21.31 accuracy:78.27%
Epoch:17 Elapsed time:22.69 accuracy:78.50%
Epoch:18 Elapsed time:24.12 accuracy:78.70%
Epoch:19 Elapsed time:25.46 accuracy:78.91%
Epoch:20 Elapsed time:26.82 accuracy:78.98%
Epoch:21 Elapsed time:28.27 accuracy:79.24%
Epoch:22 Elapsed time:29.70 accuracy:79.34%
Epoch:23 Elapsed time:31.25 accuracy:79.60%
Epoch:24 Elapsed time:32.62 accuracy:79.64%
Epoch:25 Elapsed time:34.35 accuracy:79.73%
Epoch:26 Elapsed time:35.84 accuracy:79.92%
Epoch:27 Elapsed time:37.32 accuracy:80.06%
Epoch:28 Elapsed time:38.97 accuracy:80.14%
Epoch:29 Elapsed time:40.59 accuracy:80.23%
Epoch:30 Elapsed time:42.28 accuracy:80.45%
Epoch:31 Elapsed time:43.77 accuracy:80.52%
Epoch:32 Elapsed time:45.43 accuracy:80.58%
Epoch:33 Elapsed time:47.05 accuracy:80.69%
Epoch:34 Elapsed time:48.70 accuracy:80.80%
Epoch:35 Elapsed time:50.20 accuracy:80.86%
Epoch:36 Elapsed time:51.62 accuracy:80.90%
Epoch:37 Elapsed time:53.17 accuracy:81.02%
Epoch:38 Elapsed time:54.65 accuracy:81.08%
Epoch:39 Elapsed time:56.22 accuracy:81.21%
Epoch:40 Elapsed time:57.79 accuracy:81.25%
Epoch:41 Elapsed time:59.29 accuracy:81.34%
Epoch:42 Elapsed time:60.87 accuracy:81.34%
Epoch:43 Elapsed time:62.45 accuracy:81.55%
Epoch:44 Elapsed time:64.04 accuracy:81.73%
Epoch:45 Elapsed time:65.41 accuracy:81.56%
Epoch:46 Elapsed time:66.79 accuracy:81.73%
Epoch:47 Elapsed time:68.23 accuracy:81.68%
Epoch:48 Elapsed time:69.78 accuracy:81.74%
Epoch:49 Elapsed time:71.24 accuracy:81.91%
Epoch:50 Elapsed time:72.76 accuracy:81.87%
Epoch:51 Elapsed time:74.37 accuracy:81.91%
Epoch:52 Elapsed time:76.27 accuracy:82.09%
Epoch:53 Elapsed time:78.28 accuracy:82.04%
Epoch:54 Elapsed time:80.26 accuracy:82.15%
Epoch:55 Elapsed time:82.13 accuracy:82.20%
Epoch:56 Elapsed time:83.70 accuracy:82.29%
Epoch:57 Elapsed time:85.53 accuracy:82.27%
Epoch:58 Elapsed time:87.22 accuracy:82.38%
Epoch:59 Elapsed time:88.72 accuracy:82.36%
Epoch:60 Elapsed time:90.45 accuracy:82.47%
Epoch:61 Elapsed time:92.20 accuracy:82.39%
Epoch:62 Elapsed time:93.97 accuracy:82.48%
Epoch:63 Elapsed time:95.59 accuracy:82.50%
Epoch:64 Elapsed time:97.21 accuracy:82.61%
Epoch:65 Elapsed time:98.69 accuracy:82.65%
Epoch:66 Elapsed time:100.37 accuracy:82.78%
Epoch:67 Elapsed time:102.10 accuracy:82.74%
Epoch:68 Elapsed time:103.85 accuracy:82.77%
Epoch:69 Elapsed time:105.74 accuracy:82.79%
Epoch:70 Elapsed time:107.22 accuracy:82.91%
Epoch:71 Elapsed time:108.85 accuracy:82.94%
Epoch:72 Elapsed time:110.55 accuracy:83.04%
Epoch:73 Elapsed time:112.52 accuracy:83.05%
Epoch:74 Elapsed time:114.06 accuracy:83.08%
Epoch:75 Elapsed time:115.77 accuracy:83.09%
Epoch:76 Elapsed time:117.37 accuracy:83.16%
Epoch:77 Elapsed time:119.14 accuracy:83.26%
Epoch:78 Elapsed time:120.74 accuracy:83.20%
Epoch:79 Elapsed time:122.41 accuracy:83.23%
Epoch:80 Elapsed time:123.95 accuracy:83.23%
Epoch:81 Elapsed time:125.73 accuracy:83.24%
Epoch:82 Elapsed time:127.32 accuracy:83.31%
Epoch:83 Elapsed time:128.98 accuracy:83.36%
Epoch:84 Elapsed time:130.51 accuracy:83.39%
Epoch:85 Elapsed time:132.11 accuracy:83.47%
Epoch:86 Elapsed time:133.66 accuracy:83.47%
Epoch:87 Elapsed time:135.29 accuracy:83.48%
Epoch:88 Elapsed time:136.92 accuracy:83.55%
Epoch:89 Elapsed time:138.61 accuracy:83.50%
Epoch:90 Elapsed time:140.20 accuracy:83.60%
Epoch:91 Elapsed time:141.80 accuracy:83.58%
Epoch:92 Elapsed time:143.51 accuracy:83.66%
Epoch:93 Elapsed time:145.01 accuracy:83.63%
Epoch:94 Elapsed time:146.40 accuracy:83.66%
Epoch:95 Elapsed time:147.87 accuracy:83.71%
Epoch:96 Elapsed time:149.47 accuracy:83.79%
Epoch:97 Elapsed time:151.07 accuracy:83.79%
Epoch:98 Elapsed time:152.62 accuracy:83.84%
Epoch:99 Elapsed time:154.46 accuracy:83.79%
Epoch:100 Elapsed time:156.07 accuracy:83.85%

This result is encouraging, and with further optimization we could reach higher accuracy. Possible directions, both touched on above, include tuning the learning rate and building deeper networks.
