采用目标跟踪算法实现对移动目标的自动化标注工作
2021-02-01 本文已影响0人
gaoshine
采用目标跟踪算法实现对移动目标的自动化标注工作
使用目标检测(object detect)算法时,需要对图片中的识别目标做标注,工作量极为庞大.例如:流行的yolo模型,自己的训练集需要做标注,并保存为voc格式,一般采用labelImg工具完成.
一般情况下,客户会提供视频素材,需要对视频中指定标的物实现检测和识别.通常会把视频分解成一帧帧的图片,手工标注图片中的标的物,生成VOC格式的标注文件和图片,最后再做训练.
我们尝试使用opencv中目标跟踪算法,通过跟踪标定目标,实现自动标注目标,自动生成标注格式的voc文件,完成辅助标注任务,这样不仅可以大大减轻人工标注的繁重任务,还可以大幅提升训练效率.
0. 必要的库
- 采用opencv采集视频数据.
- 通过lxml库中etree读写xml文件,生成voc格式的标注文件.
- 采用opencv的TrackerCSRT实现目标跟踪.
- 按照anno(标注格式文件)和img(标注图片)生成标注所需文件
# 采用目标跟踪算法实现对移动目标的自动化标注代码
# python3 autolabbeling.py --video test.mov --tracker csrt
# import the necessary packages
from imutils.video import VideoStream
from imutils.video import FPS
from lxml import etree
import argparse
import imutils
import time
import cv2
import os
1.运行参数
usage: autolabbeling.py [-h] [-v VIDEO] [-t TRACKER] [--name NAME]
optional arguments:
-h, --help show this help message and exit
-v VIDEO, --video VIDEO
path to input video file
-t TRACKER, --tracker TRACKER
OpenCV object tracker type
--name NAME 图片类别.
- VIDEO --> 视频文件
- TRACKER --> 跟踪算法(默认:kcf)
- NAME --> 标注类型标签(默认:person)
# Build the command-line interface and parse it into a plain dict:
#   --video   path to the input video file (webcam 0 is used when absent)
#   --tracker OpenCV tracker algorithm name (default "kcf")
#   --name    class label written into each annotation
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--video", type=str,
                    help="path to input video file")
parser.add_argument("-t", "--tracker", type=str, default="kcf",
                    help="OpenCV object tracker type")
parser.add_argument("--name", default="person", help="图片类别.")
args = vars(parser.parse_args())
# Builder/writer for PASCAL-VOC-style annotation XML files.
class GEN_Annotations:
    """Build and save a PASCAL VOC 2007 annotation document for one image.

    Typical usage: construct with the image file name, call ``set_size``
    once, ``add_pic_attr`` once per bounding box, then ``savefile``.
    """

    def __init__(self, filename):
        """Create the <annotation> root with folder/filename/source headers.

        filename -- image file name recorded in the <filename> element.
        """
        self.root = etree.Element("annotation")
        folder = etree.SubElement(self.root, "folder")
        folder.text = "VOC2007"
        fname = etree.SubElement(self.root, "filename")
        fname.text = filename
        # Fixed <source> boilerplate matching the VOC2007 convention.
        source = etree.SubElement(self.root, "source")
        src_annotation = etree.SubElement(source, "annotation")
        src_annotation.text = "PASCAL VOC2007"
        src_database = etree.SubElement(source, "database")
        src_database.text = "Unknown"
        src_image = etree.SubElement(source, "image")
        src_image.text = "flickr"
        src_flickrid = etree.SubElement(source, "flickrid")
        src_flickrid.text = "35435"

    def set_size(self, witdh, height, channel):
        """Record the image dimensions in a <size> element.

        witdh   -- image width in pixels (misspelled name kept so existing
                   keyword callers do not break).
        height  -- image height in pixels.
        channel -- number of color channels, written as <depth>.
        """
        size = etree.SubElement(self.root, "size")
        width_el = etree.SubElement(size, "width")
        width_el.text = str(witdh)
        height_el = etree.SubElement(size, "height")
        height_el.text = str(height)
        depth_el = etree.SubElement(size, "depth")
        depth_el.text = str(channel)

    def savefile(self, filename):
        """Write the annotation tree to *filename* as pretty-printed UTF-8."""
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True,
                   xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax):
        """Append one <object> with its class name and <bndbox> coordinates.

        label -- class name of the object.
        xmin, ymin, xmax, ymax -- pixel coordinates of the bounding box.
        """
        # Renamed from `object`, which shadowed the builtin.
        obj = etree.SubElement(self.root, "object")
        name_el = etree.SubElement(obj, "name")
        name_el.text = label
        bndbox = etree.SubElement(obj, "bndbox")
        xmin_el = etree.SubElement(bndbox, "xmin")
        xmin_el.text = str(xmin)
        ymin_el = etree.SubElement(bndbox, "ymin")
        ymin_el.text = str(ymin)
        xmax_el = etree.SubElement(bndbox, "xmax")
        xmax_el.text = str(xmax)
        ymax_el = etree.SubElement(bndbox, "ymax")
        ymax_el.text = str(ymax)
2.目标跟踪
# Instantiate the OpenCV tracker selected on the command line.
# BUG FIX: the original hard-coded cv2.TrackerCSRT_create(), silently
# ignoring --tracker (default "kcf") even though the HUD displays that
# name. Resolve the factory by name, falling back to CSRT when the
# requested tracker is unknown or unavailable in this OpenCV build.
_TRACKER_FACTORIES = {
    "csrt": "TrackerCSRT_create",
    "kcf": "TrackerKCF_create",
    "mil": "TrackerMIL_create",
}
_factory_name = _TRACKER_FACTORIES.get(args["tracker"], "TrackerCSRT_create")
tracker = getattr(cv2, _factory_name, cv2.TrackerCSRT_create)()

# Bounding box of the tracked object; stays None until the user presses
# 's' and selects a ROI.
initBB = None

# Open the input source: the webcam when no --video path was given,
# otherwise the video file. (The original comments had the two branches
# swapped.)
if not args.get("video", False):
    print("[INFO] starting video stream...")
    vs = VideoStream(src=0).start()
    time.sleep(1.0)  # camera warm-up
else:
    vs = cv2.VideoCapture(args["video"])

# FPS estimator; created once tracking actually starts.
fps = None
# Main loop: read frames, track the selected object, and for every frame
# where tracking succeeds save the image plus a VOC annotation file.
while True:
    frame = vs.read()
    # cv2.VideoCapture.read() returns (grabbed, frame); VideoStream returns
    # the frame directly.
    frame = frame[1] if args.get("video", False) else frame
    if frame is None:
        # End of the video file / stream.
        break

    # Normalize the working resolution; annotations use these dimensions.
    frame = imutils.resize(frame, width=800)
    (H, W) = frame.shape[:2]

    # Only track once the user has selected a bounding box.
    if initBB is not None:
        # Advance the tracker to the current frame.
        (success, box) = tracker.update(frame)
        if success:
            (x, y, w, h) = [int(v) for v in box]
            # VOC-style layout: images under imgs/, annotation XML under
            # anns/. exist_ok avoids the check-then-create race of the
            # original exists()/makedirs() pair.
            os.makedirs("imgs", exist_ok=True)
            os.makedirs("anns", exist_ok=True)
            # Timestamp-based name keeps saved frames unique.
            filename = str(time.time()) + ".jpg"
            # Build and save the annotation, then the image itself.
            anno = GEN_Annotations(filename)
            anno.set_size(W, H, 3)
            anno.add_pic_attr(args["name"], x, y, x + w, y + h)
            cv2.imwrite("imgs/" + filename, frame)
            anno.savefile("anns/" + os.path.splitext(filename)[0] + ".xml")
            # Visual feedback of the tracked box.
            cv2.rectangle(frame, (x, y), (x + w, y + h),
                          (0, 255, 0), 2)

        # Refresh the FPS estimate every frame so the HUD stays current.
        fps.update()
        fps.stop()

        # On-screen HUD. ("Label" fixes the original "Lable" typo.)
        info = [
            ("Label", args["name"]),
            ("Tracker", args["tracker"]),
            ("Success", "Yes" if success else "No"),
            ("FPS", "{:.2f}".format(fps.fps())),
        ]
        for (i, (k, v)) in enumerate(info):
            text = "{}: {}".format(k, v)
            cv2.putText(frame, text, (10, H - ((i * 20) + 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    cv2.imshow("Frame", frame)
    key = cv2.waitKey(1) & 0xFF

    if key == ord("s"):
        # Let the user draw the ROI (confirm with ENTER or SPACE), then
        # initialize the tracker and start the FPS estimator.
        initBB = cv2.selectROI("Frame", frame, fromCenter=False,
                               showCrosshair=True)
        tracker.init(frame, initBB)
        fps = FPS().start()
    elif key == ord("q"):
        # Quit on 'q'.
        break
# Release whichever capture object was opened, then tear down the GUI.
if args.get("video", False):
    # Video file: release the cv2.VideoCapture handle.
    vs.release()
else:
    # Webcam: stop the imutils VideoStream thread.
    vs.stop()
cv2.destroyAllWindows()