大数据 爬虫Python AI Sql程序员机器学习与数据挖掘

深度学习|tensorflow识别手写字体

2019-03-02  本文已影响14人  罗罗攀

我们依旧以MNIST手写字体数据集,来看看我们如何使用tensorflow来实现MLP。

数据

数据下载

这里我们通过tensorflow的模块,来下载数据集。

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

这样,我们就下载了数据集,这里的one_hot的意思是label为独热编码,也就是说我们的label就不需要预处理了。

数据情况

我们通过下面代码看看数据的情况:

MLP模型

之前我们使用过keras进行训练,只需要建立一个model,然后add加入神经网络层。tensorflow是要复杂很多,那我们一步步构建我们的模型吧。

def layer(output_dim,input_dim,inputs, activation=None):
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))
    b = tf.Variable(tf.random_normal([1, output_dim]))
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs

x = tf.placeholder("float", [None, 784])

h1=layer(output_dim=256,input_dim=784,
         inputs=x ,activation=tf.nn.relu) 

y_predict=layer(output_dim=10,input_dim=256,
                    inputs=h1,activation=None)
定义损失函数

这里我们需要自己定义函数,并进行优化处理。

y_label = tf.placeholder("float", [None, 10])

loss_function = tf.reduce_mean(
                  tf.nn.softmax_cross_entropy_with_logits
                         (logits=y_predict , 
                          labels=y_label))

optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                    .minimize(loss_function)
准确性评价
correct_prediction = tf.equal(tf.argmax(y_label  , 1),
                              tf.argmax(y_predict, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
训练

训练我们定义15轮。

trainEpochs = 15
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
epoch_list=[];loss_list=[];accuracy_list=[]
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for epoch in range(trainEpochs):
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})
        
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, 
                                   y_label: mnist.validation.labels})

    epoch_list.append(epoch);loss_list.append(loss)
    accuracy_list.append(acc)    
    print("Train Epoch:", '%02d' % (epoch+1), "Loss=", \
                "{:.9f}".format(loss)," Accuracy=",acc)
    
duration =time()-startTime
print("Train Finished takes:",duration)  
测试加预测
上一篇下一篇

猜你喜欢

热点阅读