tensorflow学习笔记-会话机制(session)
在TensorFlow中,有两种用于运行计算图(graph)的会话(session)
-
tf.Session( )
-
tf.InteractivesSession( )
1. tf.Session( )
要使用tf,我们必须先构建(定义)graph,之后才能运行graph。
1.1 非交互式会话中的例子
import tensorflow as tf
# 构建graph
a = tf.add(3, 5)
# 运行graph
sess = tf.Session() # 创建tf.Session的一个对象sess
print(sess.run(a))
sess.close() # 关闭sess对象
一个session可能会占用一些资源,比如变量、队列和读取器(reader)。我们使用sess.close()关闭会话或者使用上下文管理器释放这些不再使用的资源。
1.2 建议的tf.Session( )写法
import tensorflow as tf
# 构建graph
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.], [2.]])
product = tf.matmul(matrix1, matrix2)
# 运行graph
with tf.Session() as sess: # 使用"with"语句,自动关闭会话
print(sess.run(product))
1.3 Fetch(取回)
在使用sess.run( )运行图时,我们可以传入fetches,用于取回某些操作或tensor的输出内容。fetches可以是list,tuple,namedtuple,dict中的任意一个。fetches可以是一个列表,在op的一次运行中一起获得(而不是逐个去获取 tensor)多个tensor值。
import tensoflow as tf
from collections import namedtuple
a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
MyData = namedtuple('MyData', ['a', 'b'])
with tf.Session() as sess:
c = sess.run(a) # fetches可以为单个数a
d = sess.run([a, b]) # fetches可以为一个列表[a, b]
v = sess.run({'k1': MyData(a, b), 'k2': [b, a]})
print(c)
print(d)
print(d[0])
print(v)
'''
v is a dict and v['k1'] is a MyData namedtuple with the numpy array [10, 20] and the numpy array [1.0, 2.0]. v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array [10, 20].
'''
1.4 Feed(注入)
TensorFlow提供了feed注入机制, 它可以临时替代graph中任意op操作的输入tensor,可以对graph中任何操作提交补丁(直接插入一个tensor)。
feed机制只在调用它的方法内有效,方法结束,feed就会消失。最常见的用例是把某些特殊操作为feed注入的对象。你可以提供数据feed_dict,作为sess.run( )调用的参数。使用tf.placeholder( ),为某些操作的输入创建占位符。
import tensorflow as tf
import numpy as np
x = np.ones((2, 3))
y = np.ones((3, 2))
input1 = tf.placeholder(tf.int32)
input2 = tf.placeholder(tf.int32)
output = tf.matmul(input1, input2)
with tf.Session() as sess:
print(sess.run(output, feed_dict = {input1:x, input2:y}))
如果没有正确提供tf.placeholder( ),feed操作将产生错误。注意,feed注入的值不能是tf的tensor对象,应该是Python常量、字符串、列表、numpy ndarrays,或者TensorHandles。
1.5 分布式训练
从version 0.8之后,TensorFlow开始支持分布式计算的机器学习,而且TensorFlow会充分利用CPU、GPU等计算资源。如果检测到GPU,TensorFlow会优先使用GPU运行程序。用字符串标识设备,目前支持的设备包括:
“/cpu:0”:机器的第一个CPU。
“/gpu:0”:机器的第一个GPU, 如果有的话
“/gpu:1”:机器的第二个GPU, 以此类推
当计算机有多个GPU时,通过tf.device( ),我们可以指定用哪个GPU来执行。代码示例如下:
# 在with tf.device()下,构建graph
with tf.device("/gpu:0"):
a = tf.constant([[3., 3.]])
b = tf.constant([[2.], [2.]])
product = tf.matmul(a, b)
# 运行graph
with tf.Session() as sess:
print(sess.run(product))
2. tf.InteractivesSession( )
当python编辑环境是shell、IPython等交互式环境时,我们使用类tf.InteractiveSession代替类tf.Session,用方法tensor.eval( ),operation.run( ) 代替sess.run( ),这样可避免用一个变量sess来持有会话。其中更多地使用 tensor.eval(),所有的表达式都可以看作是tensor。
// 进入python3交互式环境
# python3
>>> import tensorflow as tf
// 进入一个交互式会话
>>> sess = tf.InteractiveSession()
>>> a = tf.constant(5.0)
>>> b = tf.constant(6.0)
>>> c = a * b
// We can just use 'c.eval()' without passing 'sess'
>>> print(c.eval())
>>> sess.close() // 关闭交互式会话
>>> exit() // 退出python3交互式环境