tensorflow断点续训
2019-03-02 本文已影响0人
一位学有余力的同学
在进行神经网络训练过程中由于一些因素导致训练无法进行,需要保存当前的训练结果下次接着训练
全连接反向传播神经网络中,训练过程的代码如下:
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
#加入断点续训功能
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
for i in range(STEPS):
xs,ys = mnist.train.next_batch(BATCH_SIZE)
_,loss_value,step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
if i % 1000 == 0:
print("Ater {} training step(s),loss on training batch is {} ".format(step,loss_value))
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)
该函数表明如果断点文件夹中包含有效断点状态文件,则返回该文件。
checkpoint_dir:存储断点文件目录
latest_filename=None:断点文件的可选名称,默认为“checkpoint”
saver.restore(sess, ckpt.model_checkpoint_path)
该函数表示恢复当前会话sess,将ckpt中的值赋给w和b
sess:表示当前会话,之前保存的结果将被加载入这个会话
ckpt.model_checkpoint_path:表示模型的存储位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫什么
其中:
MODEL_SAVE_PATH = './model/'
MODEL_NAME = 'mnist_model'
saver.save会在‘./model’中自动保存checkpoint文件,然后实现断点训练只需要在训练前添加下列代码即可
断点训练代码:
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)