Tensorflow学习笔记

深层神经网络2

2017-11-03  本文已影响25人  Manfestain

使用验证集判断模型效果

为了评测神经网络模型在不同参数下的效果,一般会从训练集中抽取一部分作为验证数据。除了使用验证数据集,还可以采用交叉验证(cross validation )的方式验证模型效果,但是使用交叉验证会花费大量的时间。但在海量数据情况下,一般采用验证数据集的形式评测模型的效果。
一般采用的验证数据分布越接近测试数据分布,模型在验证数据上的表现越可以体现模型在测试数据上的保险。
使用滑动平均模型和指数衰减的学习率在一定程度上都是限制神经网络中参数更新的速度。
在处理复杂问题时,使用滑动平均模型、指数衰减的学习率和正则化损失可以明显提升模型的训练效果。


变量管理

Tensorflow提供了通过变量名称来创建或者获取一个变量的机制,避免了复杂神经网络频繁传递参数的情况。通过该机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要将变量通过参数的形式到处传递。
Tensorflow中通过变量名获取变量的机制主要通过tf.get_variable()tf.variable_scope()函数实现。

  1. tf.get_variable()
    该函数创建变量的方法和tf.Variable()函数的用法基本一样,提供维度信息(shape)以及初始化方法(initializer)的参数。该函数的变量名称是一个必填参数,函数会根据这个名字去创建或者获取变量。当已经有同名参数时,会报错。
  2. tf.variable_scope()
    该函数可以控制tf.get_variable()函数的语义。当tf.variable_scope()函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf.get_variable()函数会直接获取已经创建的变量。如果不存在,则报错;当reuse=False或者reuse=None创建上下文管理器时,tf.get_variable()操作将创建新的变量,如果同名变量已经存在,则报错。
    同时tf.variable_scope()函数可以嵌套。新建一个嵌套的上下文管理器但不指定reuse,这时的reuse的取值和外面一层保持一致。当退出reuse设置为True的上下文之后reuse的值又回到了False(内层reuse不设置)。

同时,tf.variable_scope()函数生成的上下文管理器也会创建一个Tensorflow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。可以直接通过带命名空间名称的变量名来获取其它命名空间下的变量(创建一个名称为空的命名空间,并设置为reuse=True)。

with tf.variable_scope(" ", reuse=True):
     v5 = tf.get_variable("foo/bar/v", [1])
     print(v5.name)
===>v:0   # 0表示variable这个运算输出的第一个结果

Tensorflow模型持久化

将训练得到的模型保存下来,可以方便下次直接使用(避免重新训练花费大量的时间)。Tensorflow提供的持久化机制可以将训练之后的模型保存到文件中。
Tensorflow提供了tf.train.Saver类来保存和还原神经网络模型。当保存模型之后,目录下一般会出现三个文件,这是因为Tensorflow会将计算图的结构和图上参数值分开保存。

  1. model.ckpy.meta文件,保存了Tensorflow计算图的结构。
  2. model.ckpt文件,保存了Tensorflow程序每一个变量的取值。
  3. checkpoint文件,保存了一个目录下所有的模型文件列表。

保存模型
saver = tf.train.Saver()
saver.save(sess, "path/model.ckpt")
加载模型,此时不用进行变量的初始化过程
saver.restore(sess, "path/model.ckpt")
sess.run(result)

为了保存和加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或加载的变量,saver = tf.train.Saver([v1])。同时,tf.train.Saver类也支持在保存或者加载时给变量重命名,如果直接加载就会导致程序报变量找不到的错误,Tensorflow提供通过字典将模型保存时的变量名和要加载的变量联系起来。

v = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
saver = tf.train.Saver({"v1": v})
将原先变量名为v1的变量加载到变量v中,变量v的名称为other-v1。

这样做的目的时为了方便使用变量的滑动平均值。因为每一个变量的滑动平均值是通过影子变量维护的,如果在加载模型时直接将影子变量映射到变量自身,就不需要在调用函数来获取变量的平均值了。

为了方便加载重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore()函数来生成tf.train.Saver类所需要的变量重命名字典。

v = tf.Variable(0)
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variable_to_restore())
with tf.Session() as sess:
     saver.restore(sess, "path/model.ckpt")
     sess.run(v)

有时候不需要类似于变量初始化、模型保存等辅助节点的信息,Tensorflow提供了convert_variables_to_constants()函数将计算图中的变量及其取值通过常量的方式保存。


持久化原理及数据格式

Tensorflow程序中所有计算都会被表达为计算图上的节点。

MetaGraphDef

Tensorflow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据,元图是由MetaGraphDef Protocol Buffer定义的,MetaGraphDef中的内容构成了Tensorflow持久化的第一个文件,也就是model.ckpt.meta文件。

SSTable

持久化Tensorflow中变量的取值,tf.Saver得到的model.ckpt文件保存了所有的变量,该文件使用SSTable格式存储的,相当于一个(key, value)列表。

CheckpointState

持久化的最后一个文件名叫checkpoint,这个文件是tf.train.Saver类自动生成且自动维护的。该文件中维护了一个由tf.train.Saver类持久化的所有Tensoflow模型文件的文件名,当某个模型文件被删除时,这个模型对应的文件名也会被移除,checkpoint中内容的格式为CheckpointState Protocol Buffer

上一篇 下一篇

猜你喜欢

热点阅读