目标检测

SSD+pytorch+训练爬过的坑

2019-12-03  本文已影响0人  乔大叶_803e

SSD训练自己的数据集

我是参照他的步骤对Git上pytorch进行改动的

git SSD

这是Git上的网址

具体SSD的原理这个文章就先不做介绍了。

具体就是介绍下训练过程中爬过的坑,我是按照他的要求进行改的,但是比较坑的是他没有具体介绍如果是自己的数据集的话,是应该怎么进行更改数据集的位置,以及路径的问题。

查了网上很多的介绍,但是基本都没有介绍的,所以不得不,自己照着源码进行看。

也是希望可以养成看源码的习惯,毕竟问题是多种多样的,从源码出发一定是可以进行解决的。

因为选择的是VOC的数据集,所以需要把自己标好的VOC数据集放在对应的数据集目录下。

在train.py 中

parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                    type=str, help='VOC or COCO')
parser.add_argument('--dataset_root', default=VOC_ROOT,
                    help='Dataset root directory path')
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                    help='Pretrained base model')

可以看到是默认的是使用的是VOC数据集,但是其实是默认的使用的VOC2007 以及 VOC2012的数据集。

我把VOCdevkit放在data文件下

但是一直给我报的是找不到data/coco/coco_labels.txt

我开始不清楚,原来是因为作者是用的coco数据集做的训练所有在初始化的过程中,就要导入这些txt文件。

__init__.py文件下

from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT

from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
from .config import *
import torch
import cv2
import numpy as np

发现他在初始化就引入了 coco.py的文件

接下来我们来看
coco.py文件

from .config import HOME
import os
import os.path as osp
import sys
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import cv2
import numpy as np

COCO_ROOT = osp.join(HOME, 'data/coco/')
IMAGES = 'images'
ANNOTATIONS = 'annotations'
COCO_API = 'PythonAPI'
INSTANCES_SET = 'instances_{}.json'

这里有个COCO_ROOT写的是相对路径,就是没引到我改引入的地方,我就把这个路径直接写死了

改成我的路径了,读者可以根据自己情况进行修改。
我是直接修改成了 data 文件下了

同理把voc0712.py的文件
路径改为我的路径
/data/VOCdevkit

改好后当执行train的时候发现还是报错

原来是
由于Pytorch版本不同,较新版的代码直接运行会报错,需要修改部分代码,主要是将.data[0]的部分改成.item()

修改train.py
修改源码183.184两行

loc_loss += loss_l.data[0]
conf_loss += loss_c.data[0]

改为:

loc_loss += loss_l.item()
conf_loss += loss_c.item()

修改源码188行

print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')

改为:

print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')

修改源码165行

images, targets = next(batch_iterator)

改为:

try:
    images, targets = next(batch_iterator)
except StopIteration as e:
    batch_iterator = iter(data_loader)
    images, targets = next(batch_iterator)

修改mutibox_loss.py

修改:源码包/layers/modules/mutibox_loss.py 调换第97,98行:

loss_c = loss_c.view(num, -1)
loss_c[pos] = 0 # filter out pos boxes for now

修改第114行为:

N = num_pos.data.sum().double()
loss_l /= N
loss_c /= N

改完后总是提示的是我的
train.py 中的 images 以及 targets 错误

把具体错误查了下,说是局部变量不能再全局进行引入

train.py

# load train data
        images, targets = next(batch_iterator)

        if args.cuda:
            images = Variable(images.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        else:
            images = Variable(images)
            targets = [Variable(ann, volatile=True) for ann in targets]
        # forward
        t0 = time.time()
        out = net(images)
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)

改为:

# load train data
        images, targets = next(batch_iterator)

        if args.cuda:
            images = Variable(images.cuda())
            targets = [Variable(ann.cuda(), volatile=True) for ann in targets]
        else:
            image = Variable(images)#去掉s
            target = [Variable(ann, volatile=True) for ann in targets] #去掉s
        # forward
        t0 = time.time()
        out = net(image)#去掉s
        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, target)#去掉s

现在就剩一个错误了🤣🤣🤣🤣🤣

执行train.py 时候

RuntimeError: copy_if failed to synchronize: device-side assert triggered

问题描述:

这个问题是我在使用SSD做目标检测时遇到的,我要检测的目标有5种类别,所以我在data/config.py中的num_classes参数写了6,经过多方查找,发现了一个没注意到的细节,类别应该是6+1,那个1应该是背景。

还有一个原因就是标签的标号没有从0开始。

把class改为7后就没事了 哈哈

开始训练了

训练的时候发现他训练时候对GPU的利用也不是连续的,一会很高,一会是0 这很让人琢磨不透,我开始以为是自己没用上cuda呢

终于开始训练了,但是出现了我loss是nan错误,难受啊

下一讲就要开始解决这个问题了 。

上一篇下一篇

猜你喜欢

热点阅读