TensorFlow操作

TensorFlow 加载部分变量

2018-07-13  本文已影响0人  翻开日记
for v in tf.global_variables():
    if 'global_step' in v.name:
        var2.append(v) # 重置global step 调整学习率重新训练
    else:
        var1.append(v)
var1 = tf.train.Saver(var1)
...
"""Restore Model"""
save_file = tf.train.latest_checkpoint(hparams_at.train_dir)
if save_file:
    print(save_file)
    step = int(save_file.split('ckpt-')[-1]) + 1
    var1.restore(sess, save_file)
    sess.run(tf.variables_initializer([model.global_step]))
    print("Go on")
else:
    step = 0
    sess.run(tf.global_variables_initializer())
    print("Begin")

上一篇 下一篇

猜你喜欢

热点阅读