Tensorflow 训练好的模型保存和载入
方法一 这种存储方式在加载模型时需要再次定义网络结构
模型训练和存储
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# Method 1 (training side): train a softmax-regression classifier on MNIST and
# checkpoint its variables. The matching loader rebuilds this exact graph by
# hand, so it depends on the variable names "W" and "b" used below.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print(mnist)
learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
logits = tf.matmul(X, W) + b
pred = tf.nn.softmax(logits)
# Cross-entropy is computed from the logits: the hand-written
# -sum(Y * log(softmax)) produces log(0) = -inf (then NaN) as soon as a
# predicted probability underflows to zero.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

# Create the Saver and the checkpoint path.
saver = tf.train.Saver(max_to_keep=4)
model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
# Make sure the checkpoint directory exists before the first save.
os.makedirs(path, exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    # Number of mini-batches per epoch; does not change between epochs.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0.0
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{}".format(avg_cost))
        # No global_step is passed, so every epoch overwrites ./model/lr.*
        # (max_to_keep therefore has no effect); the loaders expect lr.meta.
        saver.save(sess, model_path, write_meta_graph=True)
    print("Optimization Finished")

    # Evaluate on the first 3000 test images while the session is still open.
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
加载模型
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

# Method 1 (loading side): the network is rebuilt by hand. It must match the
# training script, including the variable names "W" and "b", which are the
# checkpoint keys. Only the variable values are restored from the checkpoint.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
model_dir = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"

X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
pred = tf.nn.softmax(tf.matmul(X, W) + b)
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# A Saver built over the rebuilt variables maps the checkpoint keys onto them.
# import_meta_graph is not needed here and would add a second copy of the graph.
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        raise FileNotFoundError("no checkpoint found in %s" % model_dir)
    # restore() assigns the saved values, so no initializer run is needed.
    saver.restore(sess, ckpt)
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
方法二 这种存储方式在加载模型时不用定义网络结构
模型训练和存储
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# Method 2 (training side): same model as method 1, but the input placeholders
# are named and `pred` is registered in a collection. A loader can then use the
# saved .meta graph directly instead of rebuilding the network.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print(mnist)
learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

# The loader looks these placeholders up by the names "X" and "Y".
X = tf.placeholder(tf.float32, [None, 784], name="X")
Y = tf.placeholder(tf.float32, [None, 10], name="Y")
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
logits = tf.matmul(X, W) + b
pred = tf.nn.softmax(logits)
# Cross-entropy is computed from the logits: the hand-written
# -sum(Y * log(softmax)) produces log(0) = -inf (then NaN) as soon as a
# predicted probability underflows to zero.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=4)
# Register the tensor the loader will fetch with tf.get_collection("pred").
tf.add_to_collection("pred", pred)
model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
# Make sure the checkpoint directory exists before the first save.
os.makedirs(path, exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    # Number of mini-batches per epoch; does not change between epochs.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0.0
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{}".format(avg_cost))
        # No global_step is passed, so every epoch overwrites ./model/lr.*
        # (max_to_keep therefore has no effect); the loader expects lr.meta.
        saver.save(sess, model_path, write_meta_graph=True)
    print("Optimization Finished")

    # Evaluate on the first 3000 test images while the session is still open.
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
模型加载
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

# Method 2 (loading side): the graph structure comes from the saved .meta file,
# so nothing is redefined. Tensors are found through the "pred" collection and
# the placeholder names "X"/"Y" given at training time.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
model_dir = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(os.path.join(model_dir, "lr.meta"))
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        raise FileNotFoundError("no checkpoint found in %s" % model_dir)
    saver.restore(sess, ckpt)
    pred = tf.get_collection("pred")[0]
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Same 3000-image evaluation slice as the training script.
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
1.Tensorflow模型文件的组成
主要包含两个文件
- 元图 meta graph
保存完整的图结构 包含所有的变量 操作等 扩展名为meta
- 检查点文件 checkpoint
二进制文件，保存所有变量的值（权重、偏置，以及优化器的状态变量等；梯度本身并不会被保存），扩展名是 .ckpt。从 0.11 版本开始，检查点不再只用一个 .ckpt 文件表示，而是由两个文件组成：.data-00000-of-00001 和 .index
其中 .data 文件保存变量的值，.index 文件记录每个变量名对应的数据在 .data 文件中的位置
此外还有一个名为checkpoint的文件 用于保存最新检查点的记录
2.如何保存Tensorflow模型
在模型训练完成后 可调用tf.train.Saver()实例来保存所有的参数和计算图
由于 TensorFlow 中变量的值只存在于 session 中，因此需要在 session 内调用 save 来存储模型
import os
import tensorflow as tf

# name= must be given to tf.Variable (not to tf.random_normal) for the
# variables, and thus their checkpoint keys, to be called "w1" and "w2".
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# With no var_list, the Saver saves every global variable.
saver = tf.train.Saver()
# Create the target directory before saving into it.
os.makedirs('model', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'model/my_test_model')
运行后可得以下文件:
model/
├── checkpoint
├── my_test_model.data-00000-of-00001
├── my_test_model.index
└── my_test_model.meta
如果想在第1000次迭代时保存模型，可以通过 global_step 参数把步数传给 save（保存的文件名前缀后会追加 -1000）
saver.save(sess, 'my_test_model', global_step=1000)  # files are named my_test_model-1000.*
如果想每1000次迭代保存一次模型：由于 .meta 文件在第一次保存时就已创建，而且之后图结构不会再变化，因此后续只需保存变量的新值，而不必重复写入网络结构，可调用
saver.save(sess, 'my_test_model', global_step=step, write_meta_graph=False)  # variables only; reuse the first .meta
如果只想保留最新的4个检查点，同时希望每训练两小时额外永久保留一个检查点（这些检查点不会因 max_to_keep 而被删除；该参数并不会自动触发保存），可调用
saver = tf.train.Saver(
    max_to_keep=4,  # keep only the 4 newest checkpoints
    keep_checkpoint_every_n_hours=2,  # plus one permanently kept checkpoint per 2 hours
)
如果在 tf.train.Saver() 中没有指定任何参数，它会保存模型的所有变量；如果只想保存部分变量，则需要以列表或字典的形式把这些变量传进去
import os
import tensorflow as tf

# name= must be given to tf.Variable (not to tf.random_normal) for the
# variables to be called "w1" and "w2".
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# Only the listed variables are saved. The dict form, e.g. {'w1': w1},
# also lets you choose the checkpoint key of each variable.
saver = tf.train.Saver([w1, w2])
# Create the target directory before saving into it.
os.makedirs('model', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'model/my_test_model')
3.如何导入一个训练好的模型并进行修改和微调
需要完成两件事情
1.构建网络结构
可通过手动编写代码创建每一层网络结构来重构整个网络
由于保存模型时网络结构也被存进了 .meta 文件，也可以直接调用 tf.train.import_meta_graph() 函数来导入这个计算图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')  这个操作会把 .meta 文件中的计算图附加到当前默认图中，但此时图上的变量还没有值，仍需加载所有已经训练好的权重参数
2.加载参数
new_saver.restore(sess, tf.train.latest_checkpoint('./'))  其中 './' 是 checkpoint 文件所在的目录
with tf.Session() as sess:
    # Rebuild the graph from the .meta file, then load the saved variable values.
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    # Read a restored parameter back by its tensor name.
    print(sess.run('w1:0'))
4.恢复任何预先训练好的模型用于预测 (工作中的开发方式)
import tensorflow as tf

# The placeholders and the output op get explicit names so that a loader can
# look them up as 'w1:0', 'w2:0' and 'op_to_restore:0'.
w1 = tf.placeholder('float', name='w1')
w2 = tf.placeholder('float', name='w2')
b1 = tf.Variable(2.0, name='bias')
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name='op_to_restore')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch the output op: (4 + 8) * 2.0 = 24.0
    print(sess.run(w4, feed_dict={w1: 4, w2: 8}))
    saver.save(sess, 'test_model', global_step=1000)
当需要载入这个模型时,不仅需要恢复所有的计算图和权重参数 还需要准备一个新的feed_dict
用于将新的训练数据传送到网络中进行训练,可通过graph.get_tensor_by_name() 来获得对这些保存的操作和占位符变量的引用
# `graph` is the graph the saved model was imported into,
# e.g. tf.get_default_graph() after tf.train.import_meta_graph(...).
w1 = graph.get_tensor_by_name('w1:0')
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
使用不同的数据来运行相同的网络 则需要通过feed_dict来传递数据
with tf.Session() as sess:
    # Restore graph structure and variable values saved as test_model-1000.
    saver = tf.train.import_meta_graph('test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    # New inputs for the restored network: (13 + 17) * bias.
    feed_dict = {w1: 13.0, w2: 17.0}
    op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
    print(sess.run(op_to_restore, feed_dict))
如果想在原来的计算图基础上添加更多的操作和图层,并进行训练
import tensorflow as tf

with tf.Session() as sess:
    # test_model-1000 is the checkpoint containing the w1/w2 placeholders and
    # the 'op_to_restore' op used below.
    saver = tf.train.import_meta_graph('test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name('w1:0')
    w2 = graph.get_tensor_by_name('w2:0')
    feed_dict = {w1: 13, w2: 17}
    op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
    # Append a new op on top of the restored graph.
    add_on_op = tf.multiply(op_to_restore, 2)
    print(sess.run(add_on_op, feed_dict))