
Automating annotation of moving targets with an object tracking algorithm

2021-02-01  gaoshine


To train an object detection model, you first have to annotate the target objects in your images, and that labeling effort is enormous. For example, to train the popular YOLO models on your own dataset, every image has to be labeled and saved in VOC format, which is usually done by hand with a labeling tool such as labelImg.
In a typical project the customer supplies video footage and asks us to detect and recognize a specific target in it. The usual workflow is to split the video into individual frames, annotate the target in each frame by hand, produce the VOC-format annotation files alongside the images, and only then start training.
Here we try OpenCV's object tracking algorithms instead: once the target is marked in a single frame, the tracker follows it through the video and the VOC annotation files are generated automatically. This assisted labeling removes most of the tedious manual work and greatly speeds up preparing the training set.
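
For reference, the frame-splitting step of the manual workflow described above usually amounts to something like the sketch below; the output directory and file-name pattern are illustrative, and test.mov is just the sample video from the usage example further down.

# Sketch only: dump every frame of a video to disk so it can be labeled by hand.
import cv2
import os

os.makedirs("frames", exist_ok=True)          # illustrative output directory
cap = cv2.VideoCapture("test.mov")            # sample video name, as in the usage example
count = 0
while True:
    ret, frame = cap.read()
    if not ret:                               # end of the video
        break
    cv2.imwrite("frames/frame_%06d.jpg" % count, frame)
    count += 1
cap.release()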


0. Required libraries

# Automatic annotation of moving targets using an object tracking algorithm
# python3 autolabbeling.py --video test.mov --tracker csrt

# import the necessary packages
from imutils.video import VideoStream
from imutils.video import FPS
from lxml import etree
import argparse
import imutils
import time
import cv2
import os
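
These are all third-party packages; assuming a plain pip environment, something like the following installs them (opencv-contrib-python provides the CSRT/KCF trackers used below):

pip3 install imutils lxml opencv-contrib-python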

1. Command-line arguments

usage: autolabbeling.py  [-h] [-v VIDEO] [-t TRACKER] [--name NAME]
optional arguments:
  -h, --help            show this help message and exit
  -v VIDEO, --video VIDEO
                        path to input video file
  -t TRACKER, --tracker TRACKER
                        OpenCV object tracker type
  --name NAME           object class name.
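
A typical invocation, combining the sample video from the header comment with the default class name, would look like this (purely illustrative):

python3 autolabbeling.py --video test.mov --tracker csrt --name person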

# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-v", "--video", type=str,
                help="path to input video file")
ap.add_argument("-t", "--tracker", type=str, default="kcf",
                help="OpenCV object tracker type")
ap.add_argument(
    "--name",
    default="person",
    help="图片类别.")
args = vars(ap.parse_args())
# class that generates VOC-format annotation files

class GEN_Annotations:
    def __init__(self, filename):
        self.root = etree.Element("annotation")
        child1 = etree.SubElement(self.root, "folder")
        child1.text = "VOC2007"
        child2 = etree.SubElement(self.root, "filename")
        child2.text = filename
        child3 = etree.SubElement(self.root, "source")
        child4 = etree.SubElement(child3, "annotation")
        child4.text = "PASCAL VOC2007"
        child5 = etree.SubElement(child3, "database")
        child5.text = "Unknown"
        child6 = etree.SubElement(child3, "image")
        child6.text = "flickr"
        child7 = etree.SubElement(child3, "flickrid")
        child7.text = "35435"

    def set_size(self, width, height, channel):
        size = etree.SubElement(self.root, "size")
        widthn = etree.SubElement(size, "width")
        widthn.text = str(width)
        heightn = etree.SubElement(size, "height")
        heightn.text = str(height)
        channeln = etree.SubElement(size, "depth")
        channeln.text = str(channel)

    def savefile(self, filename):
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True,
                   xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
        obj = etree.SubElement(self.root, "object")
        namen = etree.SubElement(obj, "name")
        namen.text = label
        bndbox = etree.SubElement(obj, "bndbox")
        xminn = etree.SubElement(bndbox, "xmin")
        xminn.text = str(xmin)
        yminn = etree.SubElement(bndbox, "ymin")
        yminn.text = str(ymin)
        xmaxn = etree.SubElement(bndbox, "xmax")
        xmaxn.text = str(xmax)
        ymaxn = etree.SubElement(bndbox, "ymax")
        ymaxn.text = str(ymax)
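
GEN_Annotations can also be exercised on its own as a quick sanity check; the sketch below is not part of the script, and the file names and box coordinates are made up.

# Hypothetical stand-alone use of GEN_Annotations (file names and box are made up)
anno = GEN_Annotations("demo.jpg")
anno.set_size(800, 450, 3)                       # width, height, channels
anno.add_pic_attr("person", 100, 120, 260, 400)  # label, xmin, ymin, xmax, ymax
anno.savefile("demo.xml")

# demo.xml then contains roughly:
# <annotation>
#   <folder>VOC2007</folder>
#   <filename>demo.jpg</filename>
#   <source>...</source>
#   <size><width>800</width><height>450</height><depth>3</depth></size>
#   <object>
#     <name>person</name>
#     <bndbox><xmin>100</xmin><ymin>120</ymin><xmax>260</xmax><ymax>400</ymax></bndbox>
#   </object>
# </annotation>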


2. Object tracking

# OpenCV tracker constructors, selected via the --tracker argument
OPENCV_OBJECT_TRACKERS = {
    "csrt": cv2.TrackerCSRT_create,
    "kcf": cv2.TrackerKCF_create,
    "mil": cv2.TrackerMIL_create,
}
tracker = OPENCV_OBJECT_TRACKERS[args["tracker"]]()
# initialize the bounding box coordinates of the object we are going to track
initBB = None

# if no video path was supplied, grab a reference to the webcam
if not args.get("video", False):
    print("[INFO] starting video stream...")
    vs = VideoStream(src=0).start()
    time.sleep(1.0)

# otherwise, open the video file
else:
    vs = cv2.VideoCapture(args["video"])

# initialize the FPS throughput estimator
fps = None
# main loop over frames from the video stream
while True:

    frame = vs.read()
    frame = frame[1] if args.get("video", False) else frame
    if frame is None:
        break
    frame = imutils.resize(frame, width=800)
    (H, W) = frame.shape[:2]

    # check to see if we are currently tracking an object
    if initBB is not None:
        # grab the new bounding box coordinates of the object
        (success, box) = tracker.update(frame)

        # check to see if the tracking was a success
        if success:
            (x, y, w, h) = [int(v) for v in box]
            
            # create the imgs and anns directories for the VOC-style output
            if not os.path.exists("imgs"):
                os.makedirs("imgs")
            if not os.path.exists("anns"):
                os.makedirs("anns")
            # name the frame by timestamp; the image goes into imgs/
            filename = str(time.time()) + ".jpg"
            # build the VOC annotation, then save the image and the XML file
            anno = GEN_Annotations(filename)
            anno.set_size(W, H, 3)
            anno.add_pic_attr(args["name"], x, y, x + w, y + h)
            cv2.imwrite("imgs/" + filename, frame)
            anno.savefile("anns/" + os.path.splitext(filename)[0] + ".xml")
            
            cv2.rectangle(frame, (x, y), (x + w, y + h),
                          (0, 255, 0), 2)

        # update the FPS counter
        fps.update()
        fps.stop()

        # info text to be drawn on the frame
        info = [
            ("Label", args["name"]),
            ("Tracker", args["tracker"]),
            ("Success", "Yes" if success else "No"),
            ("FPS", "{:.2f}".format(fps.fps())),
        ]

        for (i, (k, v)) in enumerate(info):
            text = "{}: {}".format(k, v)
            cv2.putText(frame, text, (10, H - ((i * 20) + 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    # show the output frame
    cv2.imshow("Frame", frame)
    key = cv2.waitKey(1) & 0xFF

    # if the 's' key is selected, we are going to "select" a bounding
    # box to track
    if key == ord("s"):
        # select the bounding box of the object we want to track (make
        # sure you press ENTER or SPACE after selecting the ROI)
        initBB = cv2.selectROI("Frame", frame, fromCenter=False,
                               showCrosshair=True)

        # start OpenCV object tracker using the supplied bounding box
        # coordinates, then start the FPS throughput estimator as well
        tracker.init(frame, initBB)
        fps = FPS().start()

    # if the `q` key was pressed, break from the loop
    elif key == ord("q"):
        break

# if we are using a webcam, release the pointer
if not args.get("video", False):
    vs.stop()

# otherwise, release the file pointer
else:
    vs.release()

# close all windows
cv2.destroyAllWindows()
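
After a run, imgs/ holds the captured frames and anns/ the matching XML files. How they are arranged for training depends on the toolchain; as a rough sketch under the usual PASCAL VOC naming conventions (the directories below are not created by the script itself), the pairs could be copied into a VOC-style layout like this:

# Hypothetical helper: copy the generated files into a PASCAL-VOC-style layout.
import os
import shutil

os.makedirs("VOC2007/JPEGImages", exist_ok=True)
os.makedirs("VOC2007/Annotations", exist_ok=True)

for name in os.listdir("imgs"):
    shutil.copy(os.path.join("imgs", name), "VOC2007/JPEGImages")
for name in os.listdir("anns"):
    shutil.copy(os.path.join("anns", name), "VOC2007/Annotations")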
