机器学习深度学习的各种骚套路

Tensorflow的线程同步和停止

2020-06-07  本文已影响0人  大海龟啦啦啦

Tensorflow的多线程使用

Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程来使用这些数据,QueueRunner就是来管理这些读写队列的线程,而只用QueueRunner的话有时候会造成这种同步的卡壳,导致程序被强行关闭,因此需要QueueRunner和Coordinator的配合来进行调用,共同协作来停止绘画中的所有线程,并向在等待所有工作线程终止的程序报告。
示例如下所示:

import tensorflow as tf

q = tf.FIFOQueue(1000 , "float32")
counter = tf.Variable(0.0)
#   函数原型是tf.assign_add(ref,value,use_locking=None,name=None),作用是更新ref的值,通过增加value,即:ref = ref + value
add_op = tf.assign_add(counter , tf.constant(1.0))
#   通过enqueue函数将counter变量加入队列
enqueueData_op = q.enqueue(counter)

#   Session 是 Tensorflow 为了控制,和输出文件的执行的语句,意思就是将其加入tensorflow的对话,
#   运行 sess.run() 可以获得你要得知的运算结果, 或者是你所要运算的部分。
sess = tf.Session()
#   tf.train.QueueRunner是创建并运行线程的函数,q代表之前创建的队列,enqueue_ops代表需要加入到q的线程
#   add_op表示的是计数,enqueueData_op表示的是加入队列,这里实际创建了4个线程,两个增加计数,两个执行入队
#   这一步的作用是用多个线程向队列添加数据,这样的话就可以减少由于数据读取的慢速度影响程序整体的运行速度
qr = tf.train.QueueRunner(q , enqueue_ops=[add_op , enqueueData_op] * 2)
sess.run(tf.global_variables_initializer())

#   开启一个协调器
coord = tf.train.Coordinator()
#   启动队列运行器线程
enqueue_threads = qr.create_threads(sess , coord=coord , start=True)


for i in range(10):
    print(sess.run(q.dequeue()))

#   完成后,要求线程停止
coord.request_stop()
#   并等待这些线程完成
coord.join(enqueue_threads)
上一篇 下一篇

猜你喜欢

热点阅读