码农的世界人工智能程序员

JAVAEE与人工智能实战之--通过MNIST进行模型训练

2019-04-08  本文已影响5人  山东大葱哥

MNIST简介

一个手写数字识别库,世界上最权威的,美国邮政系统开发的,手写内容是0-9的内容,手写内容采集于美国人口调查局的员工和高中生。包括6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小,而且都是黑白色构成。

MINIST实验包含了四个文件,其中train-images-idx3-ubyte是60000个图片样本,train-labels-idx1-ubyte是这60000个图片对应的数字标签,t10k-images-idx3-ubyte是用于测试的样本,t10k-labels-idx1-ubyte是测试样本对应的数字标签。

我们以测试集中的一个图片为例来说明图片的存储形式:

MNIST图片并不是传统意义上的png或者jpg格式的图片,因为png或者jpg的图片格式,会带有很多干扰信息(如:数据块,图片头,图片尾,长度等等),这些图片会被处理成很简易的数组,图片长度为28,宽度也为28,总像素为2828=784,在MNIST存储的就是一个长度为784的数组,数组中的每个值表示每个点的RGB值,其中黑色用0表示、白色用255表示。我们可以将数组转成2828的二维数组,如下图所示,可以看出这是一个表示的是数字5的图片。

image.png

如果把像素写成图片,图片是这样的:


image.png

通过MNIST训练模型

在BP神经网络中, 层数、节点个数、学习速率、训练集、训练次数,都会影响到最终模型的泛化能力。因此,在设计模型时,节点的个数,学习速率的大小,以及训练次数都是需要考虑的。

本实例中设置神经网络层数为3层,其中输入特征为784个,每层节点数分别为300、100、10个,学习速率设置为0.5,迭代周期为30,批量设置60个。通过训练该模型在MNIST测试集上的平均准确率为96.68 %左右。

public static void main(String[] args) {
        //三层网络,各层节点数为784*300*10 输入特征 784个  隐藏层节点300个 输出层节点10个
        int[] nodeNum = {784, 300,100, 10};
        //周期被定义为向前和向后传播中所有批次的单次训练迭代。
        int epoch = 30;
        //每次批量的样本数
        int batchSize = 60;
        double learningRate=0.5;
        NetTrainAndTest.train(nodeNum, epoch, batchSize,learningRate);
    }

对模型进行序列化

为了“一次训练、多次使用”,我们对训练好的模型进行序列化存储,后续即可通过反序列化的方式读取恢复模型。

    /**
     * 通过序列化方式存储模型
     *
     * @param fileName 模型存放的文件名
     */
    public static <T> void saveModel(String fileName, T obj) {
        try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(fileName));
             ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(obj);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * 恢复模型
     *
     * @param fileName 模型持久化的存放位置 文件名
     *                 <p>
     *                 //@SuppressWarnings("unchecked")
     */
    public static <T> T restoreModel(String fileName) {
        try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(fileName));
             ObjectInputStream ois = new ObjectInputStream(bis)) {
            return (T) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

上一篇 JAVAEE与人工智能目录 [下一篇]
上一篇下一篇

猜你喜欢

热点阅读