机器学习与数据挖掘我爱编程

tensorflow如何保存pb文件和读取pb文件

2018-06-27  本文已影响3179人  不讲道理的魏同学

pb文件的能够保存tensorflow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图。
本文代码的运行软件为pycharm


保存pb文件

下面的代码展示了最简单的tensorflow四则运算计算图

import tensorflow as tf

x = tf.placeholder(tf.float32,name="input")

a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = tf.Variable(tf.constant(10.,shape=[1]),name="c")
d = tf.Variable(tf.constant(2.,shape=[1]),name="d")

tensor1 = tf.multiply(a,b,"mul")
tensor2 = tf.subtract(tensor1,c,"sub")
tensor3 = tf.div(tensor2,d,"div")
result = tf.add(tensor3,x,"add")

inial = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(inial)
    print(sess.run(a))
    print(result)
    result = sess.run(result,feed_dict={x:1.0})
    print(result)
    constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
    with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())

保存pb文件的功能主要是通过最后三行代码实现的

constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
    f.write(constant_graph.SerializeToString())

第一行代码的作用是将计算图中的变量转化为常量,并指定输出节点为“add”
第二行代码用来生成一个名为wsj.pb的文件(未指定路径的话,默认在该python代码的同路径下生成)
第三行代码的作用是将计算图写入该pb文件中

读取pb文件

import tensorflow as tf

with tf.gfile.FastGFile("wsj.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    result, x = tf.import_graph_def(graph_def,return_elements=["add:0", "input:0"])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(a))
    result = sess.run(result, feed_dict={x: 5.0})
    print(result)

上面代码主要分为两部分:读取pb文件并设置为默认的计算图;填充一个新的x值来计算结果。

读取pb文件时候需要注意的是,若要获取对应的张量必须用“tensor_name:0”的形式,这是tensorflow默认的。



若您觉得本文章对您有用,请您为我点上一颗小心心以表支持。感谢!

上一篇 下一篇

猜你喜欢

热点阅读