Tensorflow Error笔记3

2017-07-11  本文已影响0人  BookThief

愿天堂没有Tensorflow! 阿门。

NotFoundError (see above for traceback): Key local3/weights not found in checkpoint

这是一个困扰我好久的问题,在我们保存一个训练好的模型,然后找了一些测试数据来调用该模型测试模型的效果时,出现了上述错误,local3/weights可能会随机变化(比如conv1/weights)。下面调用模型的代码是Tensorflow官网上的。

with tf.Session() as sess:
                 tf.get_variable_scope().reuse_variables()             
                 print("Reading checkpoints...")
                 ckpt = tf.train.get_checkpoint_state(logs_train_dir)
                 if ckpt and ckpt.model_checkpoint_path:
                     global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                     saver.restore(sess, ckpt.model_checkpoint_path)
                     print('Loading success, global_step is %s' % global_step)
                 else:
                     print('No checkpoint file found')

看起来无懈可击,这个错误无从下手。再仔细读一下这个Error,有没有一种checkpoint模型保存的参数名字和实际网络模型参数的名字不一样的感觉?(哈哈,反正我有)。看一下自己的checkpoint和网络参数名字:

checkpoint 参数名 参数名
此时我们会产生这样一个大胆的想法(小姐姐,我想...):难道checkpoint里的参数名字和我们网络的参数名字不一样吗??
可是如何去验证这样一个大胆的想法呢? 如何去看checkpoint里的参数名呢? 如何讨得小姐姐的芳心呢?(哦哦,跑题了QAQ)我们可以使用下面的代码:
import os
model_dir = '/home/mml/siamese_net/logs/train/'
from tensorflow.python import pywrap_tensorflow
#checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
checkpoint_path = os.path.join(model_dir, "model.ckpt-9999")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) 

运行完上述代码后,发现水落石出:

参数名和数值 参数名和数值

果然,checkpoint参数名和网络的参数名是不一样的,当然会导致无法在checkpoint里找到local5,因为checkpoint里只有siamese/local5,所以只要修改统一参数名,即可顺利消除错误。

上一篇下一篇

猜你喜欢

热点阅读