机器学习

tensorflow断点续训

2019-03-02  本文已影响0人  一位学有余力的同学

在进行神经网络训练过程中由于一些因素导致训练无法进行,需要保存当前的训练结果下次接着训练
全连接反向传播神经网络中,训练过程的代码如下:

saver = tf.train.Saver()

with tf.Session() as sess:
   init_op = tf.global_variables_initializer()
   sess.run(init_op)

   #加入断点续训功能
   ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
   if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess,ckpt.model_checkpoint_path)

   for i in range(STEPS):
      xs,ys = mnist.train.next_batch(BATCH_SIZE)
      _,loss_value,step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})
      if i % 1000 == 0:
         print("Ater {} training step(s),loss on training batch is {} ".format(step,loss_value))
         saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

ckpt = tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)

该函数表明如果断点文件夹中包含有效断点状态文件,则返回该文件。
checkpoint_dir:存储断点文件目录
latest_filename=None:断点文件的可选名称,默认为“checkpoint”

saver.restore(sess, ckpt.model_checkpoint_path)

该函数表示恢复当前会话sess,将ckpt中的值赋给w和b
sess:表示当前会话,之前保存的结果将被加载入这个会话
ckpt.model_checkpoint_path:表示模型的存储位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫什么

其中:

MODEL_SAVE_PATH = './model/'
MODEL_NAME = 'mnist_model'

saver.save会在‘./model’中自动保存checkpoint文件,然后实现断点训练只需要在训练前添加下列代码即可
断点训练代码:

ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

原文链接

上一篇下一篇

猜你喜欢

热点阅读