TensorFlow:保存和提取模型

2018-09-12  本文已影响285人  ACphart

描述

方法

保存模型

import tensorflow as tf
'''导入其它库'''
pass

'''搭建网络及其他准备工作'''
pass

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

'''设置模型保存器'''
m_saver = tf.train.Saver()

'''迭代训练'''
for i in range(n):
    '''训练模型'''
    pass

    if i % e == 0:
        '''每隔e代保存一次模型'''
        '''model_path 和model_name分别是保存模型文件的路径和文件名'''
        '''global_step设置i作为每个模型文件名的后缀'''
        m_saver.save(sess, "model_path/model_name", global_step=i)

各文件说明

提取模型

'''在一个新的python脚本文件中'''
import tensorflow as tf
'''导入其他库'''
pass

'''其他数据准备工作'''
'''这里不需要重新搭建模型'''

'''提取模型,首先提取计算图,这一步相当于搭建模型'''
saver = tf.train.import_meta_graph("model/mnist.ann-10000.meta")

with tf.Session() as sess:
    '''提取保存好的模型参数'''
    '''这里注意模型参数文件名要丢弃后缀.data-00000-of-00001'''
    saver.restore(sess, "model/mnist.ann-10000")

    '''通过张量名获取张量'''
    '''这里按张量名获取了我保存的一个模型的三个张量,并换上新的名字'''
    new_x = tf.get_default_graph().get_tensor_by_name("x:0")
    new_y = tf.get_default_graph().get_tensor_by_name("y:0")
    new_y_ = tf.get_default_graph().get_tensor_by_name("y_:0")
    '''现在可以进行计算了'''
    y_1 = sess.run(new_y_, feed_dict={new_x: new_x_data, new_y: new_y_data})

print(y_1)

其他

def __init__(  self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None):
    """Creates a `Saver`."""

def save(  self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True):
     """Saves variables."""

def restore(self, sess, save_path):
    """Restores previously saved variables."""

def import_meta_graph(meta_graph_or_file, clear_devices=False,
                      import_scope=None, **kwargs):
  """Recreates a Graph saved in a `MetaGraphDef` proto."""
上一篇 下一篇

猜你喜欢

热点阅读