JAVAEE与人工智能实战之--通过MNIST进行模型训练
MNIST简介
一个手写数字识别库,世界上最权威的,美国邮政系统开发的,手写内容是0-9的内容,手写内容采集于美国人口调查局的员工和高中生。包括6万张训练图片和1万张测试图片构成的,每张图片都是28*28大小,而且都是黑白色构成。
MINIST实验包含了四个文件,其中train-images-idx3-ubyte是60000个图片样本,train-labels-idx1-ubyte是这60000个图片对应的数字标签,t10k-images-idx3-ubyte是用于测试的样本,t10k-labels-idx1-ubyte是测试样本对应的数字标签。
我们以测试集中的一个图片为例来说明图片的存储形式:
MNIST图片并不是传统意义上的png或者jpg格式的图片,因为png或者jpg的图片格式,会带有很多干扰信息(如:数据块,图片头,图片尾,长度等等),这些图片会被处理成很简易的数组,图片长度为28,宽度也为28,总像素为2828=784,在MNIST存储的就是一个长度为784的数组,数组中的每个值表示每个点的RGB值,其中黑色用0表示、白色用255表示。我们可以将数组转成2828的二维数组,如下图所示,可以看出这是一个表示的是数字5的图片。
如果把像素写成图片,图片是这样的:
image.png
通过MNIST训练模型
在BP神经网络中, 层数、节点个数、学习速率、训练集、训练次数,都会影响到最终模型的泛化能力。因此,在设计模型时,节点的个数,学习速率的大小,以及训练次数都是需要考虑的。
本实例中设置神经网络层数为3层,其中输入特征为784个,每层节点数分别为300、100、10个,学习速率设置为0.5,迭代周期为30,批量设置60个。通过训练该模型在MNIST测试集上的平均准确率为96.68 %左右。
public static void main(String[] args) {
//三层网络,各层节点数为784*300*10 输入特征 784个 隐藏层节点300个 输出层节点10个
int[] nodeNum = {784, 300,100, 10};
//周期被定义为向前和向后传播中所有批次的单次训练迭代。
int epoch = 30;
//每次批量的样本数
int batchSize = 60;
double learningRate=0.5;
NetTrainAndTest.train(nodeNum, epoch, batchSize,learningRate);
}
对模型进行序列化
为了“一次训练、多次使用”,我们对训练好的模型进行序列化存储,后续即可通过反序列化的方式读取恢复模型。
/**
* 通过序列化方式存储模型
*
* @param fileName 模型存放的文件名
*/
public static <T> void saveModel(String fileName, T obj) {
try (BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(fileName));
ObjectOutputStream oos = new ObjectOutputStream(bos)) {
oos.writeObject(obj);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* 恢复模型
*
* @param fileName 模型持久化的存放位置 文件名
* <p>
* //@SuppressWarnings("unchecked")
*/
public static <T> T restoreModel(String fileName) {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(fileName));
ObjectInputStream ois = new ObjectInputStream(bis)) {
return (T) ois.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
上一篇 | JAVAEE与人工智能目录 | [下一篇] |
---|