Saving and restoring a tf model
Saving the model:
import tensorflow as tf  # TensorFlow 1.x API (on TF 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior())
# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}
# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Create a saver object which will save all the variables
    saver = tf.train.Saver()
    # Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    # Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2
    # Now, save the graph; global_step records the training iteration and is appended to the file name
    saver.save(sess, './my_test_model', global_step=1020)
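Calling Saver.save() with global_step does not write a single file. It appends the step to the prefix and writes several files that share that prefix. The sketch below lists the names it is expected to produce, assuming TF 1.x's default V2 checkpoint format with a single data shard. checkpoint_files is a hypothetical helper, not a TensorFlow function.

```python
import os


def checkpoint_files(save_path, global_step=None, num_shards=1):
    """Hypothetical helper: the prefix and file names expected from
    saver.save(sess, save_path, global_step=...) in the V2 checkpoint format."""
    prefix = save_path if global_step is None else f"{save_path}-{global_step}"
    files = [
        prefix + ".meta",   # serialized graph, the file passed to import_meta_graph
        prefix + ".index",  # index mapping variable names to their stored data
    ]
    # Variable values, split into num_shards data files (assumed to be 1 here)
    files += [f"{prefix}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    # A small text file named "checkpoint" in the same directory records the latest prefix
    files.append(os.path.join(os.path.dirname(save_path), "checkpoint"))
    return prefix, files


prefix, files = checkpoint_files("./my_test_model", global_step=1020)
print(prefix)  # ./my_test_model-1020 (the prefix that saver.save returns)
for name in files:
    print(name)
```

The shared prefix is why the restore code loads my_test_model-1020.meta but hands restore() the bare prefix rather than a directory.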
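Note: do not omit the name="w1", name="w2", name="bias" and name="op_to_restore" arguments above. They are the key to restoring the model, because after restoring, tensors are looked up by these names. Does that mean other tf ops need no name? Or is their default name the same as the Python variable name, such as w4? Only ops you want to feed or fetch after restoring need an explicit name. An op created without name= gets an automatic name derived from its op type (for tf.add typically "Add", with "_1", "_2", ... appended for repeats). It never takes the Python variable name, which is why the lookup of "w3:0" further down fails.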
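The plain-Python sketch below imitates, in simplified form, how a TF 1.x graph assigns op names. An explicit name= is used as given, an unnamed op falls back to its op type, and repeats get a numeric suffix. make_namer is hypothetical, and the op type strings ("Placeholder", "Add", "Mul") are assumptions for illustration. The real graph naming logic also handles corner cases that this sketch ignores.

```python
def make_namer():
    """Hypothetical, simplified stand-in for a TF 1.x graph's op naming."""
    used = {}  # base name -> how many ops have already claimed it

    def unique_name(op_type, name=None):
        base = op_type if name is None else name  # no name= means: use the op type
        count = used.get(base, 0)
        used[base] = count + 1
        return base if count == 0 else f"{base}_{count}"

    return unique_name


unique_name = make_namer()
names = [
    unique_name("Placeholder", "w1"),     # w1 = tf.placeholder(..., name="w1")
    unique_name("Placeholder", "w2"),     # w2 = tf.placeholder(..., name="w2")
    unique_name("Add"),                   # w3 = tf.add(w1, w2), no name given
    unique_name("Mul", "op_to_restore"),  # w4 = tf.multiply(..., name="op_to_restore")
    unique_name("Add"),                   # a second unnamed tf.add
]
print(names)  # ['w1', 'w2', 'Add', 'op_to_restore', 'Add_1']
```

The name "w3" never appears, so a lookup by that name cannot succeed after restoring.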
Restoring and using the model:
import tensorflow as tf
sess = tf.Session()
# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1020.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
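import_meta_graph takes the .meta file, while restore takes the checkpoint prefix: everything before .meta, here ./my_test_model-1020. This prefix is the common start of the checkpoint file names, not a directory.
tf.train.latest_checkpoint('./') automatically returns that prefix for the most recently saved model in the given directory.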
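tf.train.latest_checkpoint(dir) does not pick the newest file by timestamp. It reads the small "checkpoint" text file that Saver.save keeps in that directory. The sketch below reproduces only that lookup in plain Python. latest_checkpoint_sketch is hypothetical, and the file contents written in the demo are an assumption modelled on a save with global_step=1020. The real function also checks that the checkpoint files actually exist.

```python
import os
import re
import tempfile


def latest_checkpoint_sketch(checkpoint_dir):
    """Hypothetical stand-in for tf.train.latest_checkpoint: read the prefix
    recorded in <checkpoint_dir>/checkpoint, or return None if there is none."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as fh:
        for line in fh:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # A relative prefix is resolved against the checkpoint directory
                return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)
    return None


# Demo: a hand-written "checkpoint" file, assumed to resemble what
# saver.save(sess, './my_test_model', global_step=1020) leaves behind
with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "checkpoint"), "w") as fh:
        fh.write('model_checkpoint_path: "my_test_model-1020"\n')
        fh.write('all_model_checkpoint_paths: "my_test_model-1020"\n')
    latest = latest_checkpoint_sketch(ckpt_dir)
    print(os.path.basename(latest))  # my_test_model-1020
```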
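# Access saved Variables directly; 'bias:0' means output 0 of the op named 'bias'
print(sess.run('bias:0'))
# This prints 2.0, the value of bias that we saved
# Was w restored at this point? Only as placeholders: print(sess.run('w1:0')) fails with an
# error saying the placeholder w1 must be fed a value, because a placeholder stores no data.
# How, then, do we reuse the trained weights of a neural network? Weights are Variables, and
# Variable values are saved in the checkpoint, as print(sess.run('bias:0')) above proves.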
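# Now, access the placeholders and create a feed_dict to feed new data.
# import_meta_graph already loaded the structure into the default graph; get_default_graph()
# only returns a handle to it so we can call get_tensor_by_name and similar methods.
graph = tf.get_default_graph()
# Fetch placeholder w1 from the restored graph and bind it to a local Python name
# (w1 here, but any name such as ww1 works) so it can be used as a feed_dict key
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}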
# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore,feed_dict))
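# Prints 60.0, i.e. (13 + 17) * 2 with the new w1, w2 and the saved bias
# w3 = graph.get_tensor_by_name("w3:0") raises: "The name 'w3:0' refers to a Tensor which does not exist. The operation, 'w3', does not exist in the graph."
# This happens because w3 was never given name="w3", so its op carries an automatic name instead.
Summary: the get_tensor_by_name calls are only needed to obtain local handles for the tensors we interact with, i.e. the ones we feed or fetch. Ops we never touch directly, such as w3, need no lookup. They are still restored in the graph and are evaluated as part of op_to_restore.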
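Tensor names such as 'bias:0' have the form <op name>:<output index>. An op can have several outputs, and :0 selects the first. For the same reason, graph.get_operation_by_name('input_x').outputs[0] and graph.get_tensor_by_name('input_x:0') refer to the same tensor. The plain-Python sketch below splits such names and checks them against a hypothetical set of op names for the graph saved above. It assumes the unnamed tf.add is called "Add".

```python
def split_tensor_name(tensor_name):
    """Hypothetical helper: 'bias:0' -> ('bias', 0)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"{tensor_name!r} is an op name, not a tensor name (no ':<index>')")
    return op_name, int(index)


# Assumed op names in the graph saved above (the unnamed tf.add taken to be "Add")
graph_ops = {"w1", "w2", "bias", "Add", "op_to_restore"}

for name in ["bias:0", "w1:0", "op_to_restore:0", "w3:0"]:
    op_name, index = split_tensor_name(name)
    status = "exists" if op_name in graph_ops else "does not exist in the graph"
    print(f"{name}: output {index} of op {op_name!r}, which {status}")
```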
'''
Two common approaches found online, compared:
Method 1:
Saving
Define the variables
Save them with saver.save()
import os
import tensorflow as tf
W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')
init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()
saver = tf.train.Saver()
os.makedirs("save", exist_ok=True)  # the target directory must exist before saving
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/model.ckpt")
Loading
Define the variables again (same names and shapes)
Load them with saver.restore()
import tensorflow as tf
W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')
saver = tf.train.Saver()
with tf.Session() as sess:
    # No initializer needed: restore() assigns the saved values, replacing the random ones
    saver.restore(sess, "save/model.ckpt")
    print(sess.run(W))  # the saved values [[1, 1, 1], [2, 2, 2]]
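With this method, you must redefine the model structure before using the model, and then load the values of the variables with matching names. Often, however, we would rather read a file and use the model directly, without defining the model all over again. That calls for the other method.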
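Method 1 works because a checkpoint essentially maps variable names to values. restore() looks up each variable of the newly built graph by name and overwrites its initial value. That is why the random truncated_normal initial values do not matter but the names do. The toy below models this in plain Python using hypothetical save_by_name and restore_by_name functions on dictionaries. The real Saver also requires the shapes to match, which this flattened toy does not check.

```python
import random


def save_by_name(variables):
    """Hypothetical 'checkpoint': a copy of {variable name: value}."""
    return {name: list(value) for name, value in variables.items()}


def restore_by_name(variables, checkpoint):
    """Overwrite every variable with the saved value stored under the same name."""
    for name in variables:
        if name not in checkpoint:
            # TensorFlow reports a similar "not found in checkpoint" error
            raise KeyError(f"{name!r} not found in checkpoint")
        variables[name] = list(checkpoint[name])  # the initial value is discarded


# Saving script: 'w' and 'b' with the concrete values from Method 1 (flattened)
ckpt = save_by_name({"w": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], "b": [0.0, 1.0, 2.0]})

# Loading script: same names, random initial values (like tf.truncated_normal)
model = {
    "w": [random.gauss(0.0, 1.0) for _ in range(6)],
    "b": [random.gauss(0.0, 1.0) for _ in range(3)],
}
restore_by_name(model, ckpt)
print(model["w"], model["b"])  # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0] [0.0, 1.0, 2.0]
```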
Method 2: restoring without redefining the network structure
'''
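import tensorflow as tf
# Define the model (in_dim, h1_dim and out_dim are assumed to be defined elsewhere)
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')
w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
# Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# train_op (a loss plus an optimizer) is assumed to be defined here as well
# Create the saver
saver = tf.train.Saver()  # defaults to saving all variables, in this case w1, b1, w2 and b2
# Store y in a collection so it can be retrieved at prediction time
tf.add_to_collection('pred_network', y)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(1000000):
    # Feed a training batch: replace each ... with real data; keep_prob 0.5 is only an example value
    sess.run(train_op, feed_dict={input_x: ..., input_y: ..., keep_prob: 0.5})
    if step % 1000 == 0:
        # Save a checkpoint; by default this also exports a meta graph
        # named 'my-model-{global_step}.meta'
        saver.save(sess, 'my-model', global_step=step)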
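The loop above calls saver.save() 1000 times, but tf.train.Saver keeps only the most recent checkpoints and deletes older ones. The default, set by its max_to_keep argument, is 5. The plain-Python sketch below simulates that retention with a bounded deque to show which prefixes remain. It models the bookkeeping only, not TensorFlow itself.

```python
from collections import deque

MAX_TO_KEEP = 5  # tf.train.Saver's default max_to_keep
kept = deque(maxlen=MAX_TO_KEEP)  # appending a 6th entry silently drops the oldest

# Same save schedule as the training loop: step % 1000 == 0 for step in range(1000000)
for step in range(0, 1000000, 1000):
    kept.append(f"my-model-{step}")

print(list(kept))
```

To retain more checkpoints, pass a larger max_to_keep when constructing tf.train.Saver.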
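# Prediction: run this part as a separate program (or call tf.reset_default_graph() first)
# so the imported graph does not clash with the one built above
import tensorflow as tf
with tf.Session() as sess:
    # new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    # new_saver.restore(sess, 'my-save-dir/my-model-10000')
    ckpt = tf.train.latest_checkpoint('./')  # prefix of the newest checkpoint written by saver.save above
    new_saver = tf.train.import_meta_graph(ckpt + '.meta')
    new_saver.restore(sess, ckpt)
    # tf.get_collection() returns a list; here we only need its first element
    y = tf.get_collection('pred_network')[0]
    graph = tf.get_default_graph()
    # y depends on placeholders, so sess.run(y) needs the actual samples to predict and the
    # matching parameters fed in; the placeholders are fetched with graph.get_operation_by_name
    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]
    # Predict with y: replace ... with the samples to predict; keep_prob 1.0 disables dropout
    sess.run(y, feed_dict={input_x: ..., keep_prob: 1.0})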
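At prediction time keep_prob is fed as 1.0 because of how tf.nn.dropout (TF 1.x signature dropout(x, keep_prob)) behaves. It keeps each element with probability keep_prob and scales the survivors by 1/keep_prob, so keep_prob = 1.0 leaves the input unchanged. Below is a plain-Python sketch of that inverted-dropout rule. The dropout function is hypothetical and not TensorFlow's implementation.

```python
import random


def dropout(values, keep_prob, rng=random):
    """Hypothetical inverted dropout: keep with probability keep_prob, scale by 1/keep_prob."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    # rng.random() is in [0, 1), so with keep_prob == 1.0 every element is kept
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


hidden = [0.5, 1.2, 0.0, 3.4]
print(dropout(hidden, keep_prob=1.0))  # [0.5, 1.2, 0.0, 3.4], i.e. unchanged
print(dropout(hidden, keep_prob=0.5, rng=random.Random(0)))  # some zeros, survivors doubled
```

Scaling during training keeps the expected activation the same as at prediction time, so no extra rescaling is needed once keep_prob is 1.0.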