FastAI

FastAI06-图像分割问题

2019-10-10  本文已影响0人  科技老丁哥

图像分割本质上就是对图像中每个像素点进行分类。

下面是我学习lesson3-camvid的笔记

1. 准备数据集

定义根目录,查看该目录下有哪些文件:

path = Path(r'E:\PyProjects\DataSet\Camvid\camvid')
path.ls()

[WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/codes.txt'),
WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/images'),
WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/labels'),
WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/valid.txt')]

获取图片的路径组成list,获取label的路径组成list,查看前面的几个路径:

fnames = get_image_files(path_img)
fnames[:3]

[WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/images/0001TP_006690.png'),
WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/images/0001TP_006720.png'),
WindowsPath('E:/PyProjects/DataSet/Camvid/camvid/images/0001TP_006750.png')]

1.1 显示图片及其对应的label

显示某张原始图片,fnames是否正常:

img_f = fnames[0]
img = open_image(img_f)
img.show(figsize=(5,5))

获取该图像的label,并显示出来,这个和图像分类问题不一样,图像分类问题每一张图片对应一个或多个label,但分割问题的每一个像素都对应一个label,可以看成该图片的mask,显示出来时要用特殊的函数open_mask.

get_y_fn = lambda x: path_lbl/f'{x.stem}_P{x.suffix}'
mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)

上面的get_y_fn函数表示从path_lbl目录中寻找和img_f同名但中间有_P的文件,用这个文件作为label。
所以这个label文件本质上也是一个图片,但是每个像素的值并不是RGB三元素,而是一个int值,该值就是该像素的标签。
打印出来可以看出:


1.2 制备专用的databunch

fastAI需要将所有的图像和标签组装成专用的databunch,所以此处组装方法为:

src = (SegmentationItemList.from_folder(path_img)
       .split_by_fname_file('../valid.txt')
       .label_from_func(get_y_fn, classes=codes))
data = (src.transform(get_transforms(), size=size, tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))

对于图像分割问题,使用专用的SegmentationItemList类,且从path_img这个源图像目录中加载图像,使用get_y_fn函数来加载图像对应的label。

由于原图像是从视频中截取的,所以很多图像之间非常相似,不能使用随机分割的方式得到train set 和val set,所以用官方给定的valid.txt中的图像作为val set。

在对原图像进行增强时,由于get_transform中有对图像进flip的操作,所以对应的label也要进行flip,此时要加上tfm_y=True。

1.3 查看databunch是否正常

可以用show_batch来显示出准备的databunch批次图片,看看是否正常

data.show_batch(2, figsize=(10,7))

show_batch默认显示的时train set中的批次图片。如果要显示val set中的批次图片,可以用:

data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

由此可以看出,databunch准备得没有问题

2. 模型训练

因为kaggle竞赛要求排除掉void_code,所以此处并没有用传统的accuracy,而是自定义的准确率指标:

def acc_camvid(input, target):
    target = target.squeeze(1) # 删除第1个维度
    mask = target != void_code # 需要排除void code
    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()

其含义是:一张图片中预测准确的像素点的个数占所有像素点的比例,而所有图片的这个比例求平均,就是准确率指标。

模型结构的定义为:

learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd)

对于图像分割问题,使用U-net得到的准确率要比CNN模型要好得多,所以并没有使用传统的resnet等CNN模型,U-net模型的结构图像为:


训练之后的结果为:


显示ground truth和predictions:


可以看出,在简单的10个epoch之后,就达到了0.90左右的acc.

上一篇下一篇

猜你喜欢

热点阅读