物体检测之加载数据集和画框

2022-09-19  本文已影响0人  小黄不头秃

(一)物体检测

前面咱们讨论的都是图片分类的问题,他注重的是图面中的主体,而对于其他的物体,就不会去关注。那么如果画面中有一只狗和一只猫,我们的模型该如何进行分类呢?其实我们更希望他能够做到的是,能发现图里面有一只狗和一只猫并且能够知道它们的位置,这就是物体检测。

(1)边缘框

在目标检测中,我们通常使用边界框(bounding box)来描述对象的空间位置。
边界框是矩形的,由矩形左上角的以及右下角的xy坐标决定。
另一种常用的边界框表示方法是边界框中心的(x, y)轴坐标以及框的宽度和高度。

有两种写法可以将一个物体框出来,

(2)目标检测数据集

不能和以前一样一个文件夹里面放一类图片,我们现在可能需要一个单独的文件用来存储图片的标签。例如:(图片名称,物体类别,边缘框)

COCO数据集,一共有80类物体,330K的图片,1.5M个物体。

(二)代码实现

画一个框和数据集
后面的数据集可能会用到一个香蕉集
下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip

%matplotlib inline
import torch
from d2l import torch as d2l
import numpy as np
import matplotlib.pyplot as plt

d2l.set_figsize()
img = d2l.plt.imread('../img/catdog.jpg')
d2l.plt.imshow(img)
x = torch.arange(5)
x = (x,x,x)
print(torch.stack(x, axis=-1))
print(torch.stack(x, axis=0))
def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽高)"""
    x1,y1,x2,y2 = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
    cx = (x1+x2)/2
    cy = (y1+y2)/2
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes

def box_center_to_corner(boxes):
    cx,cy,w,h = boxes[:,0],boxes[:,1],boxes[:,2],boxes[:,3]
    x1 = cx-0.5*w
    x2 = cx+0.5*w
    y1 = cy-0.5*h
    y2 = cy+0.5*h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
boxes = torch.tensor((dog_bbox, cat_bbox))
box_center_to_corner(box_corner_to_center(boxes)) == boxes
def bbox_to_rect(bbox, color):
    return d2l.plt.Rectangle(
        xy=(bbox[0],bbox[1]),
        width=bbox[2]-bbox[0],
        height=bbox[3]-bbox[1],
        fill=False,
        edgecolor= color,
        linewidth=2
    )

fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox,"blue"))
fig.axes.add_patch(bbox_to_rect(cat_bbox,"red"))

# 目标检测数据集
# 这个数据集叫香蕉集,用来检测香蕉
# 可以手动下载,也可以使用代码下载
# 下载地址:http://d2l-data.s3-accelerate.amazonaws.com/banana-detection.zip
%matplotlib inline
import os 
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
from PIL import Image
#@save
d2l.DATA_HUB['banana-detection'] = (
    d2l.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    data_dir = "../data/banana-detection/"
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
                             else 'bananas_val', 'label.csv')
    csv_data = pd.read_csv(csv_fname)
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        images.append(torchvision.io.read_image(
            os.path.join(data_dir, 'bananas_train' if is_train else
                         'bananas_val', 'images', f'{img_name}')))
        # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
        # 其中所有图像都具有相同的香蕉类(索引为0)
        targets.append(list(target))
    # print(type(images[0]))
    # print(type(targets[0]))
    return images, torch.tensor(targets).unsqueeze(1) / 256
class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
              is_train else f' validation examples'))

    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    def __len__(self):
        return len(self.features)
def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
# 把通道数移到后面去
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

以上都是书本上的写法,我一开始觉得还挺繁琐,于是自己又重新写了一下。结果发现还得是上面的这种写法效率高。

# 简单的写法, 反面教材
def read_csv(train=True):
    base_path = "../data/banana-detection/"
    if train: path = base_path + "bananas_train/label.csv"
    else: path = base_path + "bananas_val/label.csv"
    file = pd.read_csv(path)
    train_lable =file.set_index("img_name")
    features = []
    label = []
    for img_name, target in train_lable.iterrows():
        if train: features.append(torchvision.io.read_image(base_path+"bananas_train/images/"+img_name))
        else: features.append(torchvision.io.read_image(base_path+"bananas_val/images/"+img_name))
        label.append(list(target))
    # 所有时间都花在下面这个转换了,消耗的时间太多了,不推荐使用
    # 我尝试了使用其他的方法例如PIL的Image.open,结果是会消耗更多的时间
    # 我认为书本中的写法快的原因是重写了__len__()方法
    features = [item.numpy() for item in features]
    return (torch.tensor(features),torch.tensor(label).unsqueeze(1) / 256)

def load_bananas(batch_size=32):
    train_data = read_csv(True)
    test_data = read_csv(False)
    train_dataset = torch.utils.data.TensorDataset(*train_data)
    test_dataset = torch.utils.data.TensorDataset(*test_data)
    return torch.utils.data.DataLoader(train_dataset,shuffle=True,batch_size=batch_size),torch.utils.data.DataLoader(test_dataset,shuffle=True,batch_size=batch_size)

train_iter,test_iter = load_bananas()
from PIL.ImageDraw import Draw as draw
from PIL import Image
batch = next(iter(train_iter))
# 这里的permute和reshape并不一样,参数列表是矩阵的下标
# 可以理解为将原来的(c,h,w)即(0,1,2)转变为了(h,w,c)即(1,2,0)
imgs = batch[0][0].permute(1,2,0)

fig = plt.imshow(imgs)
print(batch[1][0][0])
fig.axes.add_patch(bbox_to_rect((batch[1][0][0][1:5]*256),color="r"))

上一篇下一篇

猜你喜欢

热点阅读