TensorFlow学习7:输入图片,预测结果

2018-06-15  本文已影响0人  崔业康

代码处理过程

1,模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为255减去原值以得到互补的反色
2,对图片做二值化处理
3,把图片形状拉成1行784列,并把值变成浮点型(要求像素点是0-1之间的浮点数)
4,计算求得输出y,y的最大值所对应的列表索引号就是预测结果

示例代码

#coding:utf-8
#将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值
def restore_model(testPicArr):
    #创建一个默认图,在该图中执行以下操作
    with tf.Graph().as_default() as tg:
        x=tf.placeholder(tf.float32,[None,mnist_forword.INPUT_NONE])
        y=mnist_forword.mnist_forword(x,None)
        #得到概率最大的预测值
        preValue=tf.argmax(y,1)

        #实现滑动平均模型,参数MOVING_AVERAGE_DECAY用于控制模型更新的速度
        #训练过程中会对每一个变量维护一个影子变量,这个影子变量的初始值
        #就是相应变量的初始值,每次变量更新时,影子变量就会随之更新
        variable_averages=tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        variable_to_restore=variable_averages.variable_to_restore()
        saver=tf.train.Saver(variable_to_restore)

        with tf.session() as sess:
            #通过checkpoint文件定位到最新保存的模型
            ckpt=tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess,ckpt.model_checkpoint_path)

                preValue=sess.run(preValue,feed_dict={x:testPicArr})
                return preValue
            else:
                print("No checkpoint file found")
                return -1

#预处理函数,包括resize,转变灰度图,二值化操作
def pre_pic(picName):
    img=Image.open(picName)
    prIm=img.resize((28,28),Image.ANTIALIAS)
    im_arr=np.array(reIm.convert('L'))
    #设定合理的阙值
    threshold=50
    for i in range(28):
        for j in range(28):
            im_arr[i][j]=255-im_arr[i][j]
            if(im_arr[i][j]<threshold):
                im_arr[i][j]=0
            else:
                im_arr[i][j]=255
    nm_arr=im_arr.reshape([1,784])
    nm_arr=nm_arr.astype(np.float32)
    img_ready=np.multiply(nm_arr,1.0/255.0)

    return img_ready

参考:人工智能实践:Tensorflow笔记

上一篇下一篇

猜你喜欢

热点阅读