图像预处理——随机贴图并生成标注文件的Python实现

2023-02-08  本文已影响0人  智驱力AI

1. 任务目标

在训练目标检测模型时，若数据存在以下情况：图像之间差异小、不同类别的样本数量差异大、某些目标物体的样本图片难以收集等，就需要对数据进行增广处理。本文以 fire 类别为例实现随机贴图增广：先把已标注的目标裁剪出来，再随机粘贴到背景图上，生成新的标注文件或在已有标注文件中追加新目标，且贴图时避开已有目标框，不会遮挡已有标注。

2. Python实现

2.1 将已标注的目标保存

数据存储格式（路径中不要包含中文：cv2.imread 在 Windows 下无法读取含中文的路径，会直接返回 None）：

输入文件夹:

输出文件夹:

代码:

import os
import cv2
import time
import argparse
import xml.etree.ElementTree as ET

from tqdm import tqdm

# Command-line options for the crop script, as (flag, default, help) triples.
_CLI_OPTIONS = (
    ('--dst-label', 'fire', 'label box to cut'),
    ('--input-path', 'data/fire_dataset', 'contain Annotations, JPEGImages folder'),
    ('--output-path', 'data/fire_cut', 'output path'),
)

parser = argparse.ArgumentParser(description='Read box from xml and crop from image.')
for _flag, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, help=_help)
# Parsed once at import time; main() reads the options from this global.
args = parser.parse_args()


def read_xml_box(xml_file):
    """Read every object box from a Pascal-VOC style annotation.

    Args:
        xml_file: path (or file-like object) accepted by ``ET.parse``.

    Returns:
        A list of ``[class_name, xmin, xmax, ymin, ymax]`` with integer
        coordinates. Note the x-pair-then-y-pair order, which ``main`` unpacks.
    """
    xml_anno = ET.parse(xml_file)
    result = []
    for obj in xml_anno.findall('object'):
        class_name = obj.find('name').text.strip()
        bndbox = obj.find('bndbox')
        # Some labelling tools write coordinates as floats ("12.0"); plain
        # int() rejects those, so parse through float first.
        xmin, xmax, ymin, ymax = (
            int(float(bndbox.find(tag).text))
            for tag in ('xmin', 'xmax', 'ymin', 'ymax')
        )
        result.append([class_name, xmin, xmax, ymin, ymax])
    return result


def main():
    """Crop every ``args.dst_label`` box of a VOC dataset into separate images.

    Reads ``<input_path>/JPEGImages/*`` together with the matching
    ``<input_path>/Annotations/<stem>.xml`` and writes each crop to
    ``<output_path>/<dst_label>/<stem>_<timestamp>_<box index>.jpg``.
    Images without an annotation, or that OpenCV cannot read, are skipped
    with a message.
    """
    xml_path = os.path.join(args.input_path, "Annotations")
    img_path = os.path.join(args.input_path, "JPEGImages")
    # Only boxes whose class equals dst_label are saved, so the sub-folder
    # name is always dst_label.
    save_path = os.path.join(args.output_path, args.dst_label)
    for img_name in tqdm(os.listdir(img_path)):
        stem = img_name.rsplit('.', maxsplit=1)[0]
        xml_name = '{}.xml'.format(stem)
        xml_file = os.path.join(xml_path, xml_name)
        if not os.path.exists(xml_file):
            print('{} not exists'.format(xml_name))
            continue
        img = cv2.imread(os.path.join(img_path, img_name))
        if img is None:
            # cv2.imread signals an unreadable file by returning None rather
            # than raising; slicing None would crash below.
            print('{} cannot be read'.format(img_name))
            continue
        img_height, img_width = img.shape[:2]
        boxes = read_xml_box(xml_file)
        for idx, (class_name, xmin, xmax, ymin, ymax) in enumerate(boxes):
            if class_name != args.dst_label:
                continue
            # Clamp to the image: a negative coordinate would wrap around in
            # NumPy slicing and yield a wrong or empty crop.
            xmin, xmax = max(xmin, 0), min(xmax, img_width)
            ymin, ymax = max(ymin, 0), min(ymax, img_height)
            if xmin >= xmax or ymin >= ymax:
                continue  # degenerate box; cv2.imwrite rejects empty images
            crop_img = img[ymin:ymax, xmin:xmax]
            # The box index keeps names unique even when time.time() returns
            # the same value for two crops of one image (coarse clocks).
            new_name = '{}_{}_{}.jpg'.format(
                stem, str(time.time()).replace('.', ''), idx)
            os.makedirs(save_path, exist_ok=True)
            save_name = os.path.join(save_path, new_name)
            print(save_name)
            cv2.imwrite(save_name, crop_img)


if __name__ == '__main__':
    main()

生成结果:

2.2 随机贴图扩充数据

数据存储格式:

输入文件夹:

注意：作为背景的数据可以只有图片而没有标注文件，也可以同时包含图片和标注文件。若没有标注文件，程序会新建标注文件；若有标注文件，贴图时会避开已有目标框以免遮挡，并把新目标追加到原标注内容之后。

输出文件夹:

输入参数:

代码:

import os
import time
import random
import argparse
import xml.etree.ElementTree as ET

from PIL import Image

parser = argparse.ArgumentParser()
# Directory of foreground patches to paste.
parser.add_argument('--tietudir', default='data/fire_cut/fire', help='贴图路径')
# Annotation directory of the backgrounds; boxes found here are avoided.
parser.add_argument('--xml_path', default='data/fire_bg/Annotations', help='躲避路径')
# Directory of background images.
parser.add_argument('--img_path', default='data/fire_bg/JPEGImages', help='底图路径')
# Output directory for generated images and XML files.
parser.add_argument('--save_path', default='data/output', help='保存路径')
# Number of images to generate. type=int is required: without it a value given
# on the command line arrives as a str and breaks divmod() downstream.
parser.add_argument('--gen_num', default=11, type=int, help='保存个数')
# Class label written into the generated annotations.
parser.add_argument('--cls_name', default='fire', help='目标类别名')
args = parser.parse_args()


def creat_src_xml(width_ditu, height_ditu, box, save_path, name, cls_name):
    """Write a new Pascal-VOC annotation that contains exactly one object.

    Args:
        width_ditu: width of the background ("ditu") image.
        height_ditu: height of the background image.
        box: ``(xmin, ymin, xmax, ymax)`` of the pasted object; values are
            truncated to int.
        save_path: directory the XML file is written into.
        name: file name of the XML file; also written to ``<filename>``.
        cls_name: class label of the object.
    """
    from xml.sax.saxutils import escape  # escape &, <, > in text nodes

    xml_name = name
    xml_file = os.path.join(save_path, xml_name)
    xmin, ymin, xmax, ymax = (int(v) for v in box[:4])
    lines = [
        '<annotation>\n',
        '    <folder>data</folder>\n',
        '    <filename>' + escape(xml_name) + '</filename>\n',
        '    <path>' + escape(xml_file) + '</path>\n',
        '    <source>\n',
        '        <database>Unknown</database>\n',
        '    </source>\n',
        '    <size>\n',
        '        <width>' + str(width_ditu) + '</width>\n',
        '        <height>' + str(height_ditu) + '</height>\n',
        '        <depth>3</depth>\n',
        '    </size>\n',
        '    <segmented>0</segmented>\n',
        '    <object>\n',
        '        <name>' + escape(cls_name) + '</name>\n',
        '        <pose>Unspecified</pose>\n',
        '        <truncated>0</truncated>\n',
        '        <difficult>0</difficult>\n',
        '        <bndbox>\n',
        '            <xmin>' + str(xmin) + '</xmin>\n',
        '            <ymin>' + str(ymin) + '</ymin>\n',
        '            <xmax>' + str(xmax) + '</xmax>\n',
        '            <ymax>' + str(ymax) + '</ymax>\n',
        '        </bndbox>\n',
        '    </object>\n',
        '</annotation>\n',
    ]
    # Explicit UTF-8 (as creat_xml uses) instead of the locale default, and a
    # context manager so the file is closed even if a write fails.
    with open(xml_file, 'w', encoding='UTF-8') as x:
        x.writelines(lines)

def creat_xml(box, save_path, copy_path, cls_name):
    """Copy an existing VOC annotation and append one extra object to it.

    Args:
        box: ``(xmin, ymin, xmax, ymax)`` of the new object; values are
            truncated to int.
        save_path: path of the XML file to write.
        copy_path: path of the source annotation to copy.
        cls_name: class label of the new object.

    Raises:
        ValueError: if ``copy_path`` contains no ``</annotation>`` tag.
    """
    from xml.sax.saxutils import escape  # escape &, <, > in text nodes

    with open(copy_path, encoding='UTF-8') as src:
        text = src.read()
    # Insert before the LAST closing tag instead of dropping the last line:
    # a file ending in blank lines, or written on a single line, would
    # otherwise lose content or leave the new object outside the root.
    close_at = text.rfind('</annotation>')
    if close_at < 0:
        raise ValueError('{} has no </annotation> tag'.format(copy_path))
    head = text[:close_at].rstrip(' \t')
    if head and not head.endswith('\n'):
        head += '\n'
    xmin, ymin, xmax, ymax = (int(v) for v in box[:4])
    obj_lines = [
        '    <object>\n',
        '        <name>' + escape(cls_name) + '</name>\n',
        '        <pose>Unspecified</pose>\n',
        '        <truncated>0</truncated>\n',
        '        <difficult>0</difficult>\n',
        '        <bndbox>\n',
        '            <xmin>' + str(xmin) + '</xmin>\n',
        '            <ymin>' + str(ymin) + '</ymin>\n',
        '            <xmax>' + str(xmax) + '</xmax>\n',
        '            <ymax>' + str(ymax) + '</ymax>\n',
        '        </bndbox>\n',
        '    </object>\n',
    ]
    with open(save_path, 'w', encoding='UTF-8') as x:
        x.write(head)
        x.writelines(obj_lines)
        x.write('</annotation>\n')

def read_xml_box(xml_file):
    """Read every object box from a Pascal-VOC style annotation.

    Args:
        xml_file: path (or file-like object) accepted by ``ET.parse``.

    Returns:
        A list of ``[class_name, xmin, ymin, xmax, ymax]`` with integer
        coordinates, so ``entry[1:5]`` is a corner-format box usable with
        ``compute_IOU``.
    """
    xml_anno = ET.parse(xml_file)
    result = []
    for obj in xml_anno.findall('object'):
        class_name = obj.find('name').text.strip()
        bndbox = obj.find('bndbox')
        # Some labelling tools write coordinates as floats ("12.0"); plain
        # int() rejects those, so parse through float first.
        xmin, ymin, xmax, ymax = (
            int(float(bndbox.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        result.append([class_name, xmin, ymin, xmax, ymax])
    return result

def compute_IOU(rec1, rec2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` rectangles.

    Only indices 0-3 of each argument are read. When the rectangles do not
    share a region of positive area (merely touching edges counts as no
    overlap), the integer ``0`` is returned; otherwise the ratio
    ``intersection / union`` is returned.
    """
    overlap_w = min(rec1[2], rec2[2]) - max(rec1[0], rec2[0])
    overlap_h = min(rec1[3], rec2[3]) - max(rec1[1], rec2[1])
    if overlap_w <= 0 or overlap_h <= 0:
        return 0
    inter = overlap_h * overlap_w
    area1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
    return inter / (area1 + area2 - inter)

def random_box(end1, end2, end3, end4):
    """Return a randomly placed box ``(xmin, ymin, xmax, ymax)``.

    Args:
        end1: largest allowed ``xmin`` (background width minus patch width).
        end2: largest allowed ``ymin`` (background height minus patch height).
        end3: box width.
        end4: box height.

    Raises:
        ValueError: from ``random.randint`` when ``end1`` or ``end2`` is
            negative, i.e. the patch does not fit in the background.
    """
    xmin = random.randint(0, end1)
    ymin = random.randint(0, end2)
    return xmin, ymin, xmin + end3, ymin + end4

def get_shuffle_list(img_path, gen_num):
    """Return ``gen_num`` entry names from ``img_path`` in random order.

    Entries are reused as evenly as possible: each appears
    ``gen_num // count`` times, and a random subset of ``gen_num % count``
    entries appears once more.

    Args:
        img_path: directory whose entries (``os.listdir``) are sampled.
        gen_num: number of names to return; int-like strings are accepted.

    Returns:
        The shuffled list, or ``None`` when the directory is empty.
    """
    imgs = os.listdir(img_path)
    if not imgs:
        return None
    # argparse yields a str when the count is given on the command line.
    gen_num = int(gen_num)
    random.shuffle(imgs)
    times, remainder = divmod(gen_num, len(imgs))
    name_gen = imgs[:remainder] + imgs * times
    random.shuffle(name_gen)
    return name_gen

def _fit_foreground(fg_img, bg_width, bg_height, scale=0.5):
    """Return ``fg_img`` sized to fit the background, or ``None`` if it cannot.

    A patch larger than the background is resized by ``scale`` once; if it
    still does not fit it is rejected.
    """
    fg_width, fg_height = fg_img.size
    if fg_width <= bg_width and fg_height <= bg_height:
        return fg_img
    # Image.resize returns a NEW image; keep the result (it was discarded
    # before, so the fallback never shrank anything).
    fg_img = fg_img.resize((max(1, int(fg_width * scale)),
                            max(1, int(fg_height * scale))))
    fg_width, fg_height = fg_img.size
    if fg_width > bg_width or fg_height > bg_height:
        return None
    return fg_img


def _find_free_box(cls_boxes, bg_width, bg_height, fg_width, fg_height, max_tries=50):
    """Return a random placement that overlaps none of ``cls_boxes``.

    ``cls_boxes`` entries are ``[class_name, xmin, ymin, xmax, ymax]``.
    Returns ``(xmin, ymin, xmax, ymax)``, or ``None`` after ``max_tries``
    rejected candidates.
    """
    for _ in range(max_tries):
        candidate = random_box(bg_width - fg_width, bg_height - fg_height,
                               fg_width, fg_height)
        if all(compute_IOU(box[1:5], candidate) <= 0 for box in cls_boxes):
            return candidate
    return None


def process(gen_num, tietudir, img_path, xml_path, save_path, cls_name):
    """Paste random foreground patches onto random backgrounds.

    For each of ``gen_num`` (patch, background) pairs, the patch is pasted at
    a random position that avoids every box in the background's annotation
    (if one exists in ``xml_path``). The result is saved to ``save_path`` as
    ``<bg stem>_<timestamp>_<n>.jpg``, plus an XML file with the same stem.
    That XML is either a fresh single-object annotation or a copy of the
    background annotation with the new object appended.

    Args:
        gen_num: number of images to generate (int-like strings accepted).
        tietudir: directory of foreground patches.
        img_path: directory of background images.
        xml_path: directory of background annotations (may be missing files).
        save_path: output directory, created if needed.
        cls_name: class label for the pasted object.
    """
    gen_num = int(gen_num)  # argparse yields a str when given on the CLI
    os.makedirs(save_path, exist_ok=True)
    fg_gen = get_shuffle_list(tietudir, gen_num)
    bg_gen = get_shuffle_list(img_path, gen_num)
    if fg_gen is None or bg_gen is None:
        return
    for num, (fg, bg) in enumerate(zip(fg_gen, bg_gen), start=1):
        fg_img = Image.open(os.path.join(tietudir, fg))
        # Work on an RGB canvas so a grayscale/palette background does not
        # degrade the pasted patch; the saved JPEG is RGB anyway.
        bg_img = Image.open(os.path.join(img_path, bg)).convert('RGB')
        save_name = bg.rsplit('.', maxsplit=1)[0]
        bg_xml_path = os.path.join(xml_path, '{}.xml'.format(save_name))
        cls_boxes = read_xml_box(bg_xml_path) if os.path.exists(bg_xml_path) else []
        bg_width, bg_height = bg_img.size

        fg_img = _fit_foreground(fg_img, bg_width, bg_height)
        if fg_img is None:
            continue
        fg_width, fg_height = fg_img.size
        box2 = _find_free_box(cls_boxes, bg_width, bg_height, fg_width, fg_height)
        if box2 is None:
            continue  # no collision-free spot found

        bg_img.paste(fg_img, (box2[0], box2[1]))
        # The running counter keeps names unique when the same background is
        # drawn twice within one time.time() tick.
        new_name = '{}_{}_{}'.format(save_name, str(time.time()).replace('.', ''), num)
        bg_img.save(os.path.join(save_path, new_name + '.jpg'))
        xml_name = new_name + '.xml'
        if not cls_boxes:
            creat_src_xml(bg_width, bg_height, box2, save_path, xml_name, cls_name)
        else:
            creat_xml(box2, os.path.join(save_path, xml_name), bg_xml_path, cls_name)


if __name__ == "__main__":
    process(args.gen_num,
            args.tietudir,
            args.img_path,
            args.xml_path,
            args.save_path,
            args.cls_name)

生成结果（使用 labelImg 查看）：

智驱力-科技驱动生产力

上一篇下一篇

猜你喜欢

热点阅读