Tensorflow store/restore model

2017-04-06  本文已影响0人  wlysola

最近在用 java 改写一个用 python 编写的 model,遇到了有关模型保存与恢复的问题,发现网上的资料有些混乱,在这里做一些记录。

.ckpt

1. .ckpt 全称为 checkpoint,代表着一个检查点,即为 model 训练过程中的一个快照,可能是在训练开始,也可能是在训练完成。

2. .ckpt 是由 Saver 调用 save 产生的:

saver.save(sess,"/tmp/model.ckpt")

详见 save_demo

3. 由 Saver 调用 restore 来复原 model 的数据:

saver.restore(sess,path)

详见 restore_demo

注意这里,复原的只有数据,不含 graph 信息。

4. .ckpt 不是单独的一个文件,而是一系列文件。

.ckpt 的一系列文件

其内部包含了:

①checkpoint: .ckpt 的标记信息。

.data: model 中 graph 的数据,包括各种变量,不含常量。

③.index: 索引信息。

.meta: graph 信息。

在这里要搞明白一点,一个 model 是由 graph(④) + 数据(②) 组成的。

graph 代表着执行逻辑,在 tensorflow 中,每个算子用一个 node 来表示,众多 node 组合起来便是一张图(graph),也就是我们的执行逻辑,而这些执行逻辑在 Saver 调用 save 时,会被存到 .meta 中(不含数据)。各个 node 中含有各种参数(变量,比如训练的权重),这些参数则被存储到 .data 中。graph 与数据是分别存储的。

tf.train.import_meta_graph

该方法只能恢复 graph,不恢复数据。

注意与上面提及的 saver.restore 区分,saver.restore 只恢复数据,不恢复 graph。

recover model

现在我们来讨论下,如何能恢复一个model。前面已经提过了,一个 model 由 graph 和 数据组成,所以只要能恢复这两部分就可以了,依据恢复的方法不同,可以分为两类。

①分别恢复 graph 和数据:

对于数据来说,可以用 saver.restore 来恢复。

对于graph来说,依据恢复方法不同可以分为两种:

A.硬编码恢复:在调用方法中,重新书写 graph 信息。

B. .meta 恢复:通过调用 tf.train.import_meta_graph 方法获得 graph,并配合 get_tensor_by_name 的方法来调用 model 中特定的算子(node)。

saver = tf.train.import_meta_graph('~/tmp/model.ckpt-1000.meta')

graph = tf.get_default_graph()

input = graph.get_tensor_by_name('input:0')

.meta 恢复 demo

② freezing(固化):

该方法将变量(训练的权重)固化在 graph 中,即用常量来替换 graph 中的变量,从而达到无需恢复数据,直接调用 graph 即可。权重一旦被固化就不能再修改,该方法一般用于生产环境。

freezing demo

注:笔者在测试 Java API 时,其只支持调用 freezing 后的图。


References:

TensorFlow学习系列(三):保存/恢复和混合多个模型

stackoverflow上一系列值得思考的问题

上一篇下一篇

猜你喜欢

热点阅读