[tf] Understanding Graphs and Sessions

2018-12-12  VanJordan


Some key questions

# 1. Using Graph.as_default():
g = tf.Graph()
with g.as_default():
  c = tf.constant(5.0)
  assert c.graph is g

# 2. Constructing and making default:
with tf.Graph().as_default() as g:
  c = tf.constant(5.0)
  assert c.graph is g

What is a graph

self.model = model
self.graph = tf.Graph()
with self.graph.as_default():
    self.sess = tf.Session()
    with self.sess.as_default():
        initializer = tf.contrib.layers.xavier_initializer(uniform = True)
        with tf.variable_scope("model", reuse=None, initializer = initializer):
            self.trainModel = self.model(config = self)
            if self.optimizer is not None:
                pass
            elif self.opt_method == "Adagrad" or self.opt_method == "adagrad":
                self.optimizer = tf.train.AdagradOptimizer(learning_rate = self.alpha, initial_accumulator_value=1e-20)
            elif self.opt_method == "Adadelta" or self.opt_method == "adadelta":
                self.optimizer = tf.train.AdadeltaOptimizer(self.alpha)
            elif self.opt_method == "Adam" or self.opt_method == "adam":
                self.optimizer = tf.train.AdamOptimizer(self.alpha)
            else:
                self.optimizer = tf.train.GradientDescentOptimizer(self.alpha)
            grads_and_vars = self.optimizer.compute_gradients(self.trainModel.loss)
            self.train_op = self.optimizer.apply_gradients(grads_and_vars)
        self.saver = tf.train.Saver()
        # tf.initialize_all_variables() is deprecated in TF 1.x.
        self.sess.run(tf.global_variables_initializer())

The empty graph

A computation graph's initial state is not actually empty: the implementation adds two special nodes, Source and Sink, which mark the start and end of the DAG. Source has id 0 and Sink has id 1; it follows that ordinary OP nodes have ids greater than 1.

In addition, Source and Sink are connected by a control-dependency edge, which guarantees that execution of the graph starts at Source and finishes at Sink. For this control-dependency edge, both src_output and dst_input are -1.

By convention, a computation graph containing only the Source and Sink nodes is still commonly called an empty graph.
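The Source and Sink nodes live in the C++ runtime representation and are not exposed through the Python API, so from Python a fresh graph looks completely empty. A quick check, assuming a TensorFlow 1.x environment:

```python
import tensorflow as tf

g = tf.Graph()
# A fresh graph exposes no operations to Python, even though the
# runtime's Source and Sink nodes already exist internally.
assert g.get_operations() == []

with g.as_default():
    tf.constant(1.0, name="c")

# After adding one constant, exactly one OP is visible.
assert len(g.get_operations()) == 1
assert g.get_operations()[0].name == "c"
```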

Visualizing a graph

# Build your graph.
x = tf.constant([[37.0, -23.0], [1.0, 4.0]])
w = tf.Variable(tf.random_uniform([2, 2]))
y = tf.matmul(x, w)
# ...
loss = ...
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
  # `sess.graph` provides access to the graph used in a tf.Session.
  writer = tf.summary.FileWriter("/tmp/log/...", sess.graph)

  # Perform your computation...
  for i in range(1000):
    sess.run(train_op)
    # ...

  writer.close()

Creating multiple graphs

The default graph is a property of the current thread. If you create a new thread, and wish to use the default graph in that thread, you must explicitly add a with g.as_default(): in that thread's function.

g_1 = tf.Graph()
with g_1.as_default():
  # Operations created in this scope will be added to `g_1`.
  c = tf.constant("Node in g_1")

  # Sessions created in this scope will run operations from `g_1`.
  sess_1 = tf.Session()

g_2 = tf.Graph()
with g_2.as_default():
  # Operations created in this scope will be added to `g_2`.
  d = tf.constant("Node in g_2")

# Alternatively, you can pass a graph when constructing a tf.Session:
# `sess_2` will run operations from `g_2`.
sess_2 = tf.Session(graph=g_2)

assert c.graph is g_1
assert sess_1.graph is g_1

assert d.graph is g_2
assert sess_2.graph is g_2

Execution process

Graph instances

with tf.Graph().as_default() as g:
    c = tf.constant(5.0)
    assert c.graph is g

The device method on a graph

with g.device('/gpu:0'):
  # All OPs constructed here will be placed on GPU 0.
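The device string is attached at graph-construction time, so no GPU needs to be present just to build the graph. A sketch of how the placement shows up on the resulting OPs, assuming TensorFlow 1.x (note that `g.device` only affects OPs added to `g`, so the graph must also be the default):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    with g.device('/gpu:0'):
        # This OP records a GPU placement in its node definition.
        a = tf.constant(1.0, name="a")
    # Outside the device scope, no device is recorded.
    b = tf.constant(2.0, name="b")

# The recorded device may be canonicalized (e.g. '/device:GPU:0'),
# so compare case-insensitively against the suffix.
assert a.op.device.lower().endswith('gpu:0')
assert b.op.device == ''
```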

Why do we need sessions

The default session

hello = tf.constant('hello, world')
sess = tf.Session()
with sess.as_default():
  print(hello.eval())
sess.close()

Tensor evaluation

In the example above, hello.eval() is equivalent to tf.get_default_session().run(hello). Tensor.eval is implemented roughly as follows.

class Tensor(_TensorLike):
    def eval(self, feed_dict=None, session=None):
        if session is None:
            session = get_default_session()
        return session.run(self, feed_dict)

Session types
