深度学习|tensorflow识别手写字体
2019-03-02 本文已影响14人
罗罗攀
我们依旧以MNIST手写字体数据集,来看看我们如何使用tensorflow来实现MLP。
数据
数据下载
这里我们通过tensorflow的模块,来下载数据集。
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
这样,我们就下载了数据集,这里的one_hot的意思是label为独热编码,也就是说我们的label就不需要预处理了。
数据情况
我们通过下面代码看看数据的情况:
- 55000训练集
- 5000验证集
- 10000测试集
MLP模型
之前我们使用过keras进行训练,只需要建立一个model,然后add加入神经网络层。tensorflow是要复杂很多,那我们一步步构建我们的模型吧。
- 首先是输入层,我们用placeholder来输入
- 隐含层256个神经元
- 输出10个神经元
def layer(output_dim,input_dim,inputs, activation=None):
W = tf.Variable(tf.random_normal([input_dim, output_dim]))
b = tf.Variable(tf.random_normal([1, output_dim]))
XWb = tf.matmul(inputs, W) + b
if activation is None:
outputs = XWb
else:
outputs = activation(XWb)
return outputs
x = tf.placeholder("float", [None, 784])
h1=layer(output_dim=256,input_dim=784,
inputs=x ,activation=tf.nn.relu)
y_predict=layer(output_dim=10,input_dim=256,
inputs=h1,activation=None)
定义损失函数
这里我们需要自己定义函数,并进行优化处理。
y_label = tf.placeholder("float", [None, 10])
loss_function = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits
(logits=y_predict ,
labels=y_label))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
.minimize(loss_function)
准确性评价
correct_prediction = tf.equal(tf.argmax(y_label , 1),
tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
训练
训练我们定义15轮。
trainEpochs = 15
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
epoch_list=[];loss_list=[];accuracy_list=[]
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):
for i in range(totalBatchs):
batch_x, batch_y = mnist.train.next_batch(batchSize)
sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})
loss,acc = sess.run([loss_function,accuracy],
feed_dict={x: mnist.validation.images,
y_label: mnist.validation.labels})
epoch_list.append(epoch);loss_list.append(loss)
accuracy_list.append(acc)
print("Train Epoch:", '%02d' % (epoch+1), "Loss=", \
"{:.9f}".format(loss)," Accuracy=",acc)
duration =time()-startTime
print("Train Finished takes:",duration)