【Python】多进程队列问题

2018-01-25  本文已影响302人  下里巴人也

最近在使用Python多进程时,遇到一个偶现队列死锁问题:
process1 是生产者,进程池中的两个进程 process2_1、process2_2 是消费者,它们之间通过队列 Q1 和 Q2 传输数据。具体情况如下:

用的Queue包是:

from multiprocessing import Queue

Process1代码如下:

    def run(self):
      """Capture frames in a loop and hand them to the input queue.

      Starts a helper thread running ``__recv_msg_from_main``, opens the
      video source, then loops.  While ``self.__gate_open`` is true each
      iteration sleeps 0.04 s, reads one frame, pre-processes it and puts
      ``(operation_id, frame_id, frame)`` on ``self.__in_queue``.  When
      ``ConfigManager.get_show_video()`` is 1 it additionally polls
      ``self.__out_queue`` without blocking and shows the frame it receives;
      pressing ``q`` in that window is the only way the loop ends.
      """
      thread.start_new_thread(self.__recv_msg_from_main, ("recv_main_process", "11"))
      # start a video
      video_name = "video" + str(ConfigManager.get_client_key())
      video_capture = WebcamVideoStream(ConfigManager.get_sources(),
                                        ConfigManager.get_width(),
                                        ConfigManager.get_height()).start()
      while True:
          if not self.__gate_open:
              # NOTE(review): this fires once per second for as long as the
              # gate is closed, although the message talks about an exit -
              # confirm this is the intended place for it.
              time.sleep(1)
              logger.fatal("video capture process exit abnomal !!!")
              continue

          time.sleep(0.04)
          ret, origin_frame = video_capture.read()

          if not ret:
              logger.warn("capture failed!")
              continue

          # Report the sizes of the same queues this loop puts to / gets from.
          # NOTE(review): qsize() is only approximate across processes and
          # raises NotImplementedError on platforms without sem_getvalue().
          logger.debug("video capture __input_q.size: %d, __output_q.size: %d",
                       self.__in_queue.qsize(), self.__out_queue.qsize())

          # frame pre handle
          frame = self.__frame_pre_handle(origin_frame)

          self.__in_queue.put((self.__operation_id, self.__frame_id, frame))
          self.__frame_id += 1

          if ConfigManager.get_show_video() != 1:
              continue

          try:
              # Non-blocking poll: a timeout has no effect when block is
              # False, so say so explicitly with get_nowait().
              frame_id, output_frame = self.__out_queue.get_nowait()
          except Empty:
              logger.debug("video capture output queue is empty")
              continue

          cv2.imshow(video_name, output_frame)
          if cv2.waitKey(1) & 0xFF == ord('q'):
              logger.warn("exit for waitKey!")
              break

Process2_1,Process2_2代码如下:

def worker(input_q, detect_q, update_q):
  """Consume frames from ``input_q``, run detection and publish the results.

  Loads the frozen TensorFlow graph found at ``PathManager.get_ckpt_path()``
  and opens a session on it, then loops forever:

  * blocks on ``input_q.get()`` for an ``(operation_id, frame_id, frame)``
    tuple and converts the frame from BGR to RGB;
  * runs ``parse_origin_video_frame`` and ``item_detect`` on it;
  * if any items were detected, puts a ``Frame`` holding them on
    ``detect_q``;
  * if ``ConfigManager.get_show_video()`` is 1, converts the annotated frame
    back to BGR and puts ``(frame_id, frame)`` on ``update_q``.

  The session is closed when the loop is left through an exception.
  """
  # Load a (frozen) Tensorflow model into memory.
  detection_graph = tf.Graph()
  with detection_graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(PathManager.get_ckpt_path(), 'rb') as fid:
          serialized_graph = fid.read()
          od_graph_def.ParseFromString(serialized_graph)
          tf.import_graph_def(od_graph_def, name='')

      sess = tf.Session(graph=detection_graph)

  category_index = ComponentManager.get_category_index()
  width = ConfigManager.get_width()
  height = ConfigManager.get_height()

  # The loop never returns normally, so the close has to live in a finally
  # block to be reachable at all.
  try:
      while True:
          operation_id, frame_id, frame = input_q.get()
          frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
          # One %d placeholder, one argument.
          logger.debug("video parse in_queue_size: %d", input_q.qsize())

          updated_frame, score, classes, boxes = parse_origin_video_frame(frame_rgb,
                                                                          sess,
                                                                          detection_graph,
                                                                          category_index)
          items = item_detect(score, classes, boxes, category_index)
          if items:
              valid_frame = Frame(operation_id, frame_id)
              for item in items:
                  logger.debug("frame_id: %d, item[%d %s %d (%d %d)]", frame_id, item.get_id(),
                               item.get_name(), item.get_score(), width - item.get_x(), height - item.get_y())
                  valid_frame.add_item(item)
              detect_q.put(valid_frame)

          # for imshow
          if ConfigManager.get_show_video() == 1:
              output_rgb = cv2.cvtColor(updated_frame, cv2.COLOR_RGB2BGR)
              update_q.put((frame_id, output_rgb))
  finally:
      sess.close()

class VideoParserProcess(object):
  """Owns the process pool whose processes run ``worker`` on the given queues."""

  def __init__(self, in_queue, detected_queue, update_queue):
      """Remember the three queues; no process is started until ``run()``."""
      self.__video_name = 'Video' + str(ConfigManager.get_sources())
      self.__in_queue = in_queue
      self.__detected_queue = detected_queue
      self.__update_queue = update_queue
      # Created by run(); None means no pool is currently alive.
      self.__pool = None

  def run(self):
      """Create the pool.

      ``worker`` is passed as the pool's *initializer* with the three queues
      as its arguments, so every pool process calls it on start-up.
      """
      self.__pool = Pool(ConfigManager.get_worker_num(), worker, (self.__in_queue, self.__detected_queue, self.__update_queue))

  def destory(self):
      """Terminate the pool if one exists; safe to call before ``run()`` or twice.

      The misspelled name is kept for existing callers; see ``destroy``.
      """
      if self.__pool is None:
          return
      self.__pool.terminate()
      self.__pool = None

  def destroy(self):
      """Correctly spelled alias of ``destory``."""
      self.destory()

上面代码中用到了Queue类的三个方法:put,get,qsize
源码如下:

  def put(self, obj, block=True, timeout=None):
      """Append ``obj`` to ``self._buffer`` and notify ``self._notempty``.

      ``self._sem`` is acquired first using ``block``/``timeout``; ``Full`` is
      raised when that acquisition fails.  ``self._start_thread()`` is called
      if ``self._thread`` is still ``None``.
      """
      assert not self._closed
      acquired = self._sem.acquire(block, timeout)
      if not acquired:
          raise Full

      # The with-block holds the condition's lock for the buffer update and
      # releases it on the way out, even if something inside raises.
      with self._notempty:
          if self._thread is None:
              self._start_thread()
          self._buffer.append(obj)
          self._notempty.notify()

  def get(self, block=True, timeout=None):
      """Receive one object via ``self._recv()`` and release ``self._sem``.

      With ``block`` true and no ``timeout`` the call waits on ``self._rlock``
      and then on ``self._recv()`` without any time limit.  Otherwise
      ``Empty`` is raised when ``self._rlock`` cannot be acquired with the
      given ``block``/``timeout``, or when ``self._poll`` reports nothing to
      read.  ``self._rlock`` is always released before returning or raising.
      """
      if block and timeout is None:
          # Unbounded wait: take the read lock, then receive.
          self._rlock.acquire()
          try:
              res = self._recv()
              # put() acquires _sem before buffering an item; release it here.
              self._sem.release()
              return res
          finally:
              self._rlock.release()

      else:
          if block:
              # Fix the deadline before waiting for the lock, so the time
              # spent acquiring it counts against the caller's timeout.
              deadline = time.time() + timeout
          if not self._rlock.acquire(block, timeout):
              raise Empty
          try:
              if block:
                  # Only what is left of the timeout is given to the poll.
                  timeout = deadline - time.time()
                  if not self._poll(timeout):
                      raise Empty
              elif not self._poll():
                  # Non-blocking: poll with no timeout argument.
                  raise Empty
              res = self._recv()
              self._sem.release()
              return res
          finally:
              self._rlock.release()

  def qsize(self):
      """Return ``self._maxsize`` minus the current value of ``self._sem``."""
      # Raises NotImplementedError on Mac OSX because of broken sem_getvalue()
      capacity = self._maxsize
      sem_value = self._sem._semlock._get_value()
      return capacity - sem_value

这几个接口内部都涉及锁(或信号量)。实际把打印 qsize 的地方都去掉之后,问题就没有再复现。待继续看看源码。初入 Python,要学的还很多。

上一篇下一篇

猜你喜欢

热点阅读