
Using checkpoints in TensorFlow

2018-11-12  上行彩虹人

Saving a model is not limited to after training has finished; you also need to save during training, because a TensorFlow training run can always be interrupted. Naturally, we want to keep the hard-won intermediate parameters, otherwise we would have to start over from scratch next time. Saving the model in the middle of training like this is conventionally called saving a checkpoint.
1. A linear regression example

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Training data: y = 2x plus Gaussian noise
train_x = np.linspace(-1,1,100)
train_y = 2* train_x + np.random.randn(*train_x.shape)*0.3

tf.reset_default_graph()

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.random_normal([1]),name='weight')
b = tf.Variable(tf.zeros([1]),name='bias')
predict = tf.multiply(w,x)+b

cost = tf.reduce_mean(tf.square(y-predict))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 200
display_step = 2

2. Saving checkpoints

# max_to_keep: number of most recent checkpoints to keep
saver = tf.train.Saver(max_to_keep=2)
savedir = 'log/'
os.makedirs(savedir, exist_ok=True)  # make sure the checkpoint directory exists
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        sess.run(optimizer, feed_dict={x: train_x, y: train_y})
        loss = sess.run(cost, feed_dict={x: train_x, y: train_y})
        # Print the loss every display_step epochs
        if epoch % display_step == 0:
            print('epoch:', epoch, 'loss:', loss)
        # Save a checkpoint at the end of every epoch
        saver.save(sess, savedir + 'linemodel.cpkt', global_step=epoch)
    print('Finish')

    # Plot the training data and the fitted line
    plt.plot(train_x, train_y, 'go', label='training data')
    plt.plot(train_x, sess.run(w) * train_x + sess.run(b), color='red', label='fitted line')
    plt.legend()
    plt.show()
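
Because max_to_keep=2, only the two most recent checkpoints (here epochs 198 and 199) survive in the log/ folder. A minimal sketch to verify which ones were kept, using the standard tf.train.get_checkpoint_state and tf.train.latest_checkpoint helpers:

# List the checkpoints the Saver kept (should be the last two epochs)
ckpt_state = tf.train.get_checkpoint_state(savedir)
if ckpt_state:
    print('kept checkpoints:', ckpt_state.all_model_checkpoint_paths)
    print('latest checkpoint:', ckpt_state.model_checkpoint_path)

# Shortcut that returns only the newest checkpoint path
print(tf.train.latest_checkpoint(savedir))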

Files generated under the log/ folder:
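
For every saved step the Saver typically writes a pair of files, linemodel.cpkt-<epoch>.index and linemodel.cpkt-<epoch>.data-00000-of-00001, plus a linemodel.cpkt-<epoch>.meta file holding the graph definition, and a single checkpoint text file that records which paths are current. A minimal sketch (assuming the epoch-198 checkpoint is one of the two kept, which it is here) that peeks at the stored variables with tf.train.NewCheckpointReader:

# Inspect a checkpoint directly, without building a session
reader = tf.train.NewCheckpointReader(savedir + 'linemodel.cpkt-198')
# Map of variable name -> shape stored in the checkpoint
print(reader.get_variable_to_shape_map())
# Read the saved value of the weight variable by the name it was given in the graph
print(reader.get_tensor('weight'))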

3. Loading the saved checkpoint in a new session

with tf.Session() as sess2:
    # Running init is not strictly required here, since restore overwrites
    # every saved variable, but it is harmless
    sess2.run(init)
    # Restore the checkpoint saved at epoch 198
    saver.restore(sess2, savedir + 'linemodel.cpkt-' + str(198))
    print(sess2.run(w))
    # Predict y for x = 10 with the restored parameters
    print(10 * sess2.run(w) + sess2.run(b))
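
Hard-coding the epoch number ties the code to one particular run. A minimal variant (sess3 is just an illustrative name) that restores whatever checkpoint is newest in log/ via tf.train.latest_checkpoint:

with tf.Session() as sess3:
    # Find the newest checkpoint file in the directory, if any
    latest = tf.train.latest_checkpoint(savedir)
    if latest:
        saver.restore(sess3, latest)
        print('restored from:', latest)
        print(sess3.run(w), sess3.run(b))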