Tensorflow

使用单个模型文件进行预测

2018-06-29  本文已影响0人  拉赫曼

先上代码

import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
from tensorflow.python.lib.io import file_io

input_tensor_key = 'Placeholder:0'

def loadNpData(filename):
    tensor_key_feed_dict = {}

    #inputs = preprocess_inputs_arg_string(inputs_str)
    data = np.load(file_io.FileIO(filename, mode='r'))

    # When no key is specified for the input file.
    # Check if npz file only contains a single numpy ndarray.
    if isinstance(data, np.lib.npyio.NpzFile):
        variable_name_list = data.files
        if len(variable_name_list) != 1:
            raise RuntimeError(
                'Input file %s contains more than one ndarrays. Please specify '
                'the name of ndarray to use.' % filename)
        tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
    else:
        tensor_key_feed_dict[input_tensor_key] = data
    return tensor_key_feed_dict

with tf.Session() as sess:
    # 定义模型文件及样本测试文件
    model_filename = 'merge1_graph.pb'
    example_png = 'examples.npy'
    # 加载npy格式的图片测试样本数据
    image_data = loadNpData(example_png)
    #加载模型文件
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef();
        graph_def.ParseFromString(f.read())

    # 获取输入节点的tensor
    inputs = sess.graph.get_tensor_by_name("Placeholder:0");
    #打印输入节点的信息
    #print inputs
    # 导入计算图,定义输入节点及输出节点
    output = tf.import_graph_def(graph_def, input_map={'Placeholder:0':inputs}, return_elements=[ 'ArgMax:0','Softmax:0']) 
    # 打印输出节点的信息
    #print output
    results = sess.run(output, feed_dict={inputs:image_data[input_tensor_key]})
    print 'ArgMax result(预测结果对应的标签值):'  
    print results[0]
    print 'Softmax result(最后一层的输出):'
    print results[1]
    # 输出node详细信息,此处默认只打印第一个节点
    for node in graph_def.node:
        print node
        break

运行输出

ArgMax result(预测结果对应的标签值):
[3 3]
Softmax result(最后一层的输出):
[[4.1668140e-12 9.0696268e-18 6.4261091e-13 9.9999940e-01 1.7161388e-30
  5.4321697e-07 7.6357951e-09 6.3293229e-19 1.3812791e-13 1.5360580e-12]
 [1.1472046e-05 3.3404951e-10 6.0365837e-09 9.9997592e-01 9.8635665e-15
  5.7557719e-07 1.1977763e-05 1.6275100e-16 7.2288098e-10 5.0601763e-08]]

此处加载的关键在于tf.import_graph_def函数的参数配置,三个参数graph_def input_map return_elements

第一个参数是导入的图
input_map是指定输入节点,如果不指定,后面run的时候会报错 ==You must feed a value for placeholder tensor 'Placeholder'==

return_elements 是指定运算后的输出节点,此处就是我们想要得到的标签估计值 ArgMax 以及 最后一层节点输出 Softmax

模型的测试参考 将Tensorflow模型导出为单个文件

上一篇下一篇

猜你喜欢

热点阅读