mnist网络框架搭建

2017-09-20  本文已影响0人  苟且偷生小屁屁

首先,定义网络模型mnist_model

from keras.models import Sequential
from keras.layers import Conv2d,MaxPool2d,Flatten,Dense
图片.png
def mnist_model():

# 第一卷积层:
mnist_model.add(Conv2D(6,(5,5),activation='relu',input_shape=input_shape))
#第一池化层
mnist_model.add(MaxPool2D(pool_size=(2,2)))
#第二卷积层,这层不需要指定input
mnist_model.add(Conv2D(16,(5,5),activation='relu'))
#第二池化层
mnist_model.add(MaxPool2D(pool_size))
Flatten层
mnist.model.add(Flatten())
Dense层
mnist_model.add(Dense(120,activation='relu'))
Dense层
mnist_model.add(Dense(120,activation='relu'))
return mnist_mode

model函数部分到此为止


在主函数部分,首先也是导入依赖的数据库

import model  # 这是自己定义的模型函数
import numpy  # 涉及reshape的操作,需要依赖numpy.reshape
from keras.dataset import mnist  #导入keras自带的mnist数据库

首先,分出训练集和测试集,

(x_train,y_train),(x_test,y_test)=mnist.load_data()

然后,这里的x_train是60000,28,28的数组tuple,因为第一层必须指定input_shape,一般情况下,tensorflow的变量结构为:“数量,行,列,颜色通道”。(不好说啊,,)

if keras.backend.image_data_format() = 'chanel_first'
# 此时希望数组的维度从60000,28,28转到60000,1,28,28
x_train = numpy.reshape(x_train,[x_train.shape[0],1,28,28])
x_test = numpy.reshape(x_test,[x_test.reshape[1],1,28,28])
else
x_train = numpy.reshape(x_train,[x_train.shape[0],28,28,1])
x_test = numpy.reshape(x_test,[x_train.shape[0],28,28,1])
# 一般情况下是后一种,也是tensorflow变量的标准存储形式

然后,是一些基本参数,涉及到batch_size, epoch, input_shape,n_class

batch_size = 32
epoch = 12
n_class = 10

然后是y_train,y_test转换成one-hot形式

y_train = keras.utils.np_utils.to_categorical(y_train,n_class)
y_test = keras.utils.np_utils.to_categorical(y_test,n_class)
model = model.mnist_model(input_data, n_class)
model.compile(optimizer='Adadelta', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train,y_train,batch_size,epoch,verbose=1)

载入模型,就是将模型放在这,
然后是编译模型,最重要的三个参数,一个是优化器,一个是损失,一个是性能评估
最后是训练模型,需要训练数据,训练标签,batch_size,epoch,
verbose是是否显示进度条的选项。
如果需要保存model

keras.models.save_model(temp_model,'mnist.h5')
del temp_model
temp_model = load_model('mnist.h5')
score = temp_model.evaluate(x_test,y_test,verbose=0)
print('test loss:',score[0])
print('test accuracy:',score[1])
上一篇 下一篇

猜你喜欢

热点阅读