我爱编程

tensorflow学习笔记-会话机制(session)

2018-05-08  本文已影响259人  听风1996

在TensorFlow中,有两种用于运行计算图(graph)的会话(session)

1. tf.Session( )

要使用tf,我们必须先构建(定义)graph,之后才能运行graph。

1.1 非交互式会话中的例子

import tensorflow as tf

# 构建graph
a = tf.add(3, 5) 

# 运行graph
sess = tf.Session()  # 创建tf.Session的一个对象sess
print(sess.run(a)) 

sess.close()         # 关闭sess对象

一个session可能会占用一些资源,比如变量、队列和读取器(reader)。我们使用sess.close()关闭会话或者使用上下文管理器释放这些不再使用的资源。

1.2 建议的tf.Session( )写法

import tensorflow as tf  

# 构建graph
matrix1 = tf.constant([[3., 3.]])  
matrix2 = tf.constant([[2.], [2.]])  

product = tf.matmul(matrix1, matrix2)  

# 运行graph
with tf.Session() as sess:          # 使用"with"语句,自动关闭会话
    print(sess.run(product))  

1.3 Fetch(取回)

在使用sess.run( )运行图时,我们可以传入fetches,用于取回某些操作或tensor的输出内容。fetches可以是list,tuple,namedtuple,dict中的任意一个。fetches可以是一个列表,在op的一次运行中一起获得(而不是逐个去获取 tensor)多个tensor值。

import tensoflow as tf
from collections import namedtuple

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
MyData = namedtuple('MyData', ['a', 'b'])

with tf.Session() as sess:
    c = sess.run(a)            # fetches可以为单个数a
    d = sess.run([a, b])       # fetches可以为一个列表[a, b]
    v = sess.run({'k1': MyData(a, b), 'k2': [b, a]}) 

    print(c)
    print(d)
    print(d[0])
    print(v) 
'''
v is a dict and v['k1'] is a MyData namedtuple with the numpy array [10, 20] and the numpy array [1.0, 2.0]. v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array [10, 20].
'''

1.4 Feed(注入)

TensorFlow提供了feed注入机制, 它可以临时替代graph中任意op操作的输入tensor,可以对graph中任何操作提交补丁(直接插入一个tensor)。

feed机制只在调用它的方法内有效,方法结束,feed就会消失。最常见的用例是把某些特殊操作为feed注入的对象。你可以提供数据feed_dict,作为sess.run( )调用的参数。使用tf.placeholder( ),为某些操作的输入创建占位符。

import tensorflow as tf
import numpy as np

x = np.ones((2, 3))
y = np.ones((3, 2)) 

input1 = tf.placeholder(tf.int32)
input2 = tf.placeholder(tf.int32)

output = tf.matmul(input1, input2)

with tf.Session() as sess:
    print(sess.run(output, feed_dict = {input1:x, input2:y}))

如果没有正确提供tf.placeholder( ),feed操作将产生错误。注意,feed注入的值不能是tf的tensor对象,应该是Python常量、字符串、列表、numpy ndarrays,或者TensorHandles。

1.5 分布式训练

从version 0.8之后,TensorFlow开始支持分布式计算的机器学习,而且TensorFlow会充分利用CPU、GPU等计算资源。如果检测到GPU,TensorFlow会优先使用GPU运行程序。用字符串标识设备,目前支持的设备包括:

“/cpu:0”:机器的第一个CPU。

“/gpu:0”:机器的第一个GPU, 如果有的话

“/gpu:1”:机器的第二个GPU, 以此类推

当计算机有多个GPU时,通过tf.device( ),我们可以指定用哪个GPU来执行。代码示例如下:

# 在with tf.device()下,构建graph
with tf.device("/gpu:0"):
    a = tf.constant([[3., 3.]])
    b = tf.constant([[2.], [2.]])
    product = tf.matmul(a, b)

# 运行graph
with tf.Session() as sess:    
    print(sess.run(product))

2. tf.InteractivesSession( )

当python编辑环境是shell、IPython等交互式环境时,我们使用类tf.InteractiveSession代替类tf.Session,用方法tensor.eval( ),operation.run( ) 代替sess.run( ),这样可避免用一个变量sess来持有会话。其中更多地使用 tensor.eval(),所有的表达式都可以看作是tensor。

// 进入python3交互式环境
# python3

>>> import tensorflow as tf  

// 进入一个交互式会话
>>> sess = tf.InteractiveSession()

>>> a = tf.constant(5.0)
>>> b = tf.constant(6.0)
>>> c = a * b

// We can just use 'c.eval()' without passing 'sess'
>>> print(c.eval()) 

>>> sess.close()   // 关闭交互式会话

>>> exit()        // 退出python3交互式环境
上一篇下一篇

猜你喜欢

热点阅读