神经网络与深度学习机器学习机器学习与计算机视觉

在Windows下使用Tensorflow Object Det

2017-07-09  本文已影响7458人  Daisy丶

Tensorflow Object Detection API是Tensorflow官方发布的一个建立在TensorFlow之上的开源框架,可以轻松构建,训练和部署对象检测模型。TensorFlow官方使用TensorFlow Slim项目框实现了近年来提出的多种优秀的深度卷积神经网络框架。

Tensorflow Object Detection API可以选择的模型:

Githubhttps://github.com/tensorflow/models/tree/master/object_detection

在本文中,我们实现了在Windows环境下运行该框架的流程。在此之前我们要使用相关的卷积模型,需要自行编译作者指定的Caffe,不同的框架使用的Caffe版本也不尽相同。而基于其他深度学习框架的代码受制于作者水平的不同,可用性与效率也不尽相同,因此TOD API在Tensorflow上提供了了一套标准化的编写模式,既有利于使用,也有为编写其他模型提供了例子。

环境

首先我们安装Tensorflow,最新的版本为1.2。在python 3.5+使用Tensorflow非常的简单,不需要过多的流程,只需要使用pip进行安装,所有相关的依赖就会自动安装完成。

# For CPU
pip install tensorflow
# For GPU
pip install tensorflow-gpu

其次官方要求下列包,我们一同使用pip进行安装。

pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib

Tensorflow Object Detection API使用Protobufs来配置模型和训练参数。 在使用框架之前,必须编译Protobuf库。对于protobuf,在Linux下我们可以使用apt-get安装,在Windows下我们可以直接下载已经编译好的版本,这里我们选择下载列表中的protoc-3.3.0-win32.zip。

Githubhttps://github.com/google/protobuf/releases

我们将bin文件夹加入到环境变量中,然后在CMD执行protco命令,可以看到protobuf要求输入文件。

protoc.jpg

接下来我们切换到models目录下,使用protoc命令编译.proto文件

# From tensorflow/models/
protoc object_detection/protos/*.proto --python_out=.

我们可以看见.proto文件已经被编译为了.py文件。

proto.jpg

官方提供了一个object_detection_tutorial.ipynb文件,这个Demo会自动下载并执行最小最快的模型Single Shot Multibox Detector (SSD) with MobileNet。检测结果如下:

1.png 2.png

为了方便在项目中使用,我们重写了一个Python文件,其中网络模型可以从下面的地址下载,每一个模型都有一个frozen_inference_graph.pb文件。代码与运行结果如下:

Tensorflow detection model:
https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md

# coding:utf8
import os
import sys
import cv2
import numpy as np
import tensorflow as tf
sys.path.append("..")

from utils import label_map_util
from utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        # Path to frozen detection graph. This is the actual model that is used for the object detection.
        self.PATH_TO_CKPT = 'frozen_inference_graph.pb'

        # List of the strings that is used to add correct label for each box.
        self.PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

        self.NUM_CLASSES = 90

        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box represents a part of the image where a particular object was detected.
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score represent how level of confidence for each of the objects.
                # Score is shown on the result image, together with the class label.
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        while True:
            cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
            cv2.imshow("detection", image)
            if cv2.waitKey(110) & 0xff == 27:
                break


if __name__ == '__main__':
    image = cv2.imread('dog.jpg')
    detecotr = TOD()
    detecotr.detect(image)

test.jpg
上一篇下一篇

猜你喜欢

热点阅读