[tf]模型存储和加载

2018-12-07  本文已影响0人  VanJordan

saver.save(sess, 'my-model', global_step=1000)  # ==> checkpoint prefix: 'my-model-1000'
saver = tf.train.Saver() 默认保存默认图上的全部 Variable 数据；也可以通过 var_list 参数指定只保存哪些 Variable，例如 tf.train.Saver(var_list)（var_list 可以是变量列表或字典，见下文）。

模型的加载

# Saver over all variables of the default graph, used here for restoring.
loader = tf.train.Saver()
# Saver.restore expects a checkpoint *path prefix* (e.g. 'dir/model.ckpt-1000'),
# not a directory; latest_checkpoint reads model_dir's "checkpoint" file and
# returns the newest prefix recorded there, or None if there is none.
ckpt_path = tf.train.latest_checkpoint(model_dir)
if ckpt_path is None:
    raise FileNotFoundError("no checkpoint found in %s" % model_dir)
loader.restore(sess, ckpt_path)
# Two variables to checkpoint ('...' stands for the initial value).
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict: each key is the name the variable is stored
# under in the checkpoint, each value is the variable itself.
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

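下面的 setup_loader 用 self.var_list 构造 Saver，意思是只加载 var_list 中的变量，这里即 embedding 变量（rel_emb 和 ent_emb，见下方 TransE 模型中的 self.var_list）。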

def setup_loader(self):
    """Create ``self.loader``, a Saver restricted to ``self.var_list``.

    Only the variables listed in ``self.var_list`` are covered by this
    Saver, so restoring through it leaves every other variable untouched.
    """
    self.loader = tf.train.Saver(var_list=self.var_list)

def load_session(self, itr):
    """Restore ``self.loader``'s variables from the checkpoint of iteration *itr*.

    The checkpoint prefix is ``<model_name>_weights/<dataset>/<itr>.ckpt``,
    the same layout ``save_model`` writes.

    Args:
        itr: iteration identifier; ints and strings are both accepted
            (it is converted with ``str``, matching ``save_model``).
    """
    # str() so an int iteration number does not raise TypeError on concatenation.
    filename = self.model_name + "_weights/" + self.dataset + "/" + str(itr) + ".ckpt"
    self.loader.restore(self.sess, filename)
-----------------------TransE model中的self.var_list---------------------
# Relation and entity embedding tables, each initialised uniformly in
# [-sqrt_size, sqrt_size]. sqrt_size is defined outside this snippet
# (presumably derived from emb_size -- confirm in the model code).
# Keep the creation order rel -> ent: under a graph-level seed, per-op
# random seeds can depend on the order in which ops are created.
emb_dim = self.params.emb_size
rel_init = tf.random_uniform(shape=[self.num_rel, emb_dim], minval=-sqrt_size, maxval=sqrt_size)
self.rel_emb = tf.get_variable(name="rel_emb", initializer=rel_init)
ent_init = tf.random_uniform(shape=[self.num_ent, emb_dim], minval=-sqrt_size, maxval=sqrt_size)
self.ent_emb = tf.get_variable(name="ent_emb", initializer=ent_init)
# These two variables are exactly what setup_loader's Saver restores.
self.var_list = [self.rel_emb, self.ent_emb]

模型的保存

# Saver over the default graph's variables; max_to_keep=0 means older
# checkpoint files are never deleted automatically.
saver = tf.train.Saver(max_to_keep=0)
saver.save(sess=self.sess, save_path=filename)
def setup_saver(self):
    """Create ``self.saver`` for writing checkpoints of the default graph.

    No ``var_list`` is passed, so the Saver covers every saveable variable.
    ``max_to_keep=0`` stops it from deleting older checkpoint files.
    """
    keep_all_checkpoints = 0
    self.saver = tf.train.Saver(max_to_keep=keep_all_checkpoints)

def save_model(self, itr):
    """Write ``self.saver``'s checkpoint for iteration *itr*.

    The checkpoint prefix is ``<model_name>_weights/<dataset>/<itr>.ckpt``.
    Its parent directory is created first if it does not exist yet.

    Args:
        itr: iteration identifier; converted with ``str``.
    """
    filename = self.model_name + "_weights/" + self.dataset + "/" + str(itr) + ".ckpt"
    # exist_ok=True avoids the race between an exists() check and makedirs()
    # when another process creates the directory in between.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    self.saver.save(self.sess, filename)

例子：保存并加载模型

# construct graph
v1 = tf.Variable([0], name='v1')
v2 = tf.Variable([0], name='v2')
# run graph and save: writes checkpoint, ckp.index, ckp.data-*, ckp.meta
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, 'ckp')

# Restore: rebuild the graph structure from ckp.meta, then load the values.
# Start from an empty default graph so the structure comes only from
# ckp.meta instead of being imported next to the v1/v2 defined above.
tf.reset_default_graph()
with tf.Session() as sess:
    # import_meta_graph is exposed as tf.train.import_meta_graph.
    saver = tf.train.import_meta_graph('ckp.meta')
    saver.restore(sess, 'ckp')

执行 Saver.save 操作时，会在文件系统中生成如下文件：

├── checkpoint
├── ckp.data-00000-of-00001
├── ckp.index
└── ckp.meta
上一篇下一篇

猜你喜欢

热点阅读