SSD data_augmentation对自己数据集做增强

2019-08-22  本文已影响0人  miahuang

前言:

目前我在做车辆目标检测任务,虽然对实时性的要求不高,但是对检测的准确性有比较高的要求.使用yolo ,retinanet 神经网络进行检测的时候发现, 喂数据的多少,很影响检测的结果.不论是做什么任务,数据一直都是一个比较头痛的问题. ssd是一个优秀的网络模型.在数据增强方法做了很多处理,例如裁剪,明亮强度等.我在github上面,找到了ssd源码,https://github.com/amdegroot/ssd.pytorch, 理解数据增强代码逻辑,并对自己的数据集进行增强处理.

augmentation.py 源码理解

源码的文件结构是很清晰的. 路径utils/augmentation.py就是数据增强的代码.这份源码每个函数做的工作其实看函数名就很清楚. 例如class RandomBrightness 就是对图片随机增加亮度.

class RandomBrightness(object):
    def __init__(self, delta=32):
        assert delta >= 0.0
        assert delta <= 255.0
        self.delta = delta
    def __call__(self, image, boxes=None, labels=None):
        if random.randint(2):
            delta = random.uniform(-self.delta, self.delta)
            image += delta
        return image, boxes, labels

我想重点是这个class SSDAugmentation

class SSDAugmentation(object):
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        self.augment = Compose([
            ConvertFromInts(),
            ToAbsoluteCoords(),
            PhotometricDistort(),
            Expand(self.mean),
            RandomSampleCrop(),
            RandomMirror(),
            ToPercentCoords(),
            Resize(self.size),
            SubtractMeans(self.mean)
        ])
    def __call__(self, img, boxes, labels):

        return self.augment(img, boxes, labels)

不论在上面的 class RandomBrightness ,SSDAugmentation,还是文件中的其他 class ,都定义了一个call 方法,我查了这个函数的使用方法,发现是python 的魔法方法,作用是让类的对象也能够作为一个函数被调用.理解了这个魔法方法,我才发现这份代码真的是写的很美,很值得我去学习. 赞叹完了,还有另外一个地方要注意, 那就是self.augmentation = Compose([.....]) , 成员变量 augmentation 是Compose的一个对象,所以调用 SSDAugmentation 的call方法时候,就会执行 Compose 类的call 方法,对图像进行一系列的数据处理.

数据格式转换

在ssd的源码中,支持voc,coco的数据格式,因此我也把自己的数据集提前转成voc的格式.ssd 读取voc数据的代码在data/voc0712.py, 从class VOCDetection这个类开始阅读,便可以知道整个处理流程.

class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the
            input image
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load
            (default: 'VOC2007')
    """

    def __init__(self, root,
                 # image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 image_sets=[('2007', 'trainval')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC0712'):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)

            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))
        for i in range( len(self.ids)):
            self.pull_item(i)

编写自己的脚本

逻辑理解后,代码实现和结果见下.

from utils.augmentations import  SSDAugmentation
from data import  myvoc0712 as myvoc
from data import  config
from scipy import misc
import cv2
import random
import numpy as np
import argparse
parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                    type=str, help='VOC or COCO')
parser.add_argument('--dataset_root', default=myvoc.VOC_ROOT,
                    help='Dataset root directory path')
args = parser.parse_args()
cfg = config.voc
aug = SSDAugmentation(cfg['min_dim'],config.MEANS)
dataset = myvoc.VOCDetection(root=args.dataset_root,
                       transform=SSDAugmentation(cfg['min_dim'],
                                                 config.MEANS))
ssd.png ssd2.png
上一篇下一篇

猜你喜欢

热点阅读