图像分割

图像语义分割实践(二)数据增强与读取

2022-04-23  本文已影响0人  智能之心

图像语义分割实践(二)数据增强与读取

Pytorch数据加载顺序

神经网络模型训练过程需要进行梯度更新,梯度更新可分三种方式。1.批梯度下降(batch gradient descent):一次所有数据批计算,过于复杂,计算缓慢;2.随机梯度下降(stochastic gradient descent):每次读一个数据,数据差异大,导致训练波动太大,收敛性不好;3.最小批量梯度下降(mini-batch gradient descent / SGD gradient descent):随机取一定量数据进行训练,既降低计算量,又能提高训练速度。

使用pytorch对数据进行批次量读取构建,首先了解其加载数据顺序分为以下三个点。

pytorch中加载数据的顺序分为以下三个点:
1."创建一个 dataset 对象"; 并加入 transforms 数据增强方案;
2."创建一个 dataloader 对象";
3."获取数据集的 mini_batch"; 循环 dataloader 对象, 获取训练样本送入模型进行训练;

其中, 
"1.创建一个 dataset 对象", 继承 pytorch 的 torch.utils.data.Dataset; 一般需要含3个主要函数:
    1.__init__:    传入数据, 或者直接加载固化的数据包;
    2.__len__:     返回这个数据集一共有多少个item;
    3.__getitem__:  返回一条训练数据, 并将其转换成tensor;

"2.创建一个 dataloader 对象", 采用 pytorch 的 torch.utils.data.DataLoader 整合成 mini_batch;

"3.获取数据集的 mini_batch"

Pytorch官方示例与实践改造

Pytorch官方示例与实践改造

1.构建dataset对象.png 2.构建dataloader对象.png 3.索引minibatch数据.png

数据加载万能模板

针对自己数据集进行分装,数据列表单元+数据增强单元是我们需要关注的点,所以只要在这两个函数进行改造,其他部分和官方的1.dataset对象,2.dataloader对象,3.mini_batch获取一致。


4.minibatch可视化.png

模板代码示例

######## py内置函数:help-文件架构, dir-代码架构 ########
import torch # 包含基本,加减乘除,张量操作,优化器'torch.optim', 数据索引 'torch.utils.data.DataLoader'
import torch.nn as nn # "类":   包含卷积,池化,激活,损失等 "nn.CrossEntropyLoss()"
import torch.nn.functional as F  # "函数": 包含卷积,池化,激活,损失等 "F.cross_entropy()"
import torchvision # 包含图像算法的基本操作等 torchvision.models; torchvision.datasets;
import torchvision.transforms as T # "类":   包含图像增强方向等 "T.RandomCrop()"
import torchvision.transforms.functional as TF # "函数": 包含图像增强方向等 "TF.center_crop()"
import os
import glob
import math
import numpy as np
import random
from PIL import Image
import PIL
import matplotlib.pyplot as plt


#################### 构建 lines 可略 ####################
class MyLinesGetter(object):
    def __init__(self, FilePath, dtype="seg"):
        self.FilePath = FilePath
        self.dtype = dtype # None="cls", "seg"
    def getter(self):
        self.datalines = []
        with open(self.FilePath, "r") as f:
            lines = f.read().splitlines()
            if self.dtype is 'seg':
                for line in lines:
                    img_dir, seg_dir = line.split(" ")[:2]
                    img_dir = os.path.join("data_flowers", "JPEGImages", img_dir)
                    seg_dir = os.path.join("data_flowers", "SegmentationClassRAW", seg_dir)
                    self.datalines.append([img_dir, seg_dir])
            else:
                raise "wrong dtype! check dtype on ['seg']!"
        return self.datalines

#################### 创建 dataset class ####################
class SegmentDataset(torch.utils.data.Dataset): # 继承
    def __init__(self, dataset, transforms=None):
        self.dataset = dataset
        self.transforms = transforms
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img_dir, seg_dir = self.dataset[idx]
        img = Image.open(img_dir)
        seg = Image.open(seg_dir)
        if self.transforms is not None:
            data_dict = self.transforms(img, seg)
            img = data_dict['image']
            seg = data_dict["mask"]
        else:
            img = TF.to_tensor(img)
            seg = torch.as_tensor(np.array(seg), dtype=torch.int64)
        return img, seg
    pass

#################### 创建 transforms+Compose 增强方案 ####################
class Resize(object):
    def __init__(self, size):
        self.size = size
    
    def __call__(self, image, target=None, label=None):
        image = TF.resize(image, self.size)
        if target is not None:
            target = TF.resize(target, self.size, interpolation=PIL.Image.BILINEAR) # PIL.Image.BILINEAR
        if label is not None:
            label = label
        return image, target, label
    pass

class ToTensor(object):
    def __call__(self, image, target=None, label=None):
        image = TF.to_tensor(image)
        if target is not None:
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target, label
    pass

# 可用 torchvision 里面的 compose, 为方便看过程,因此自己实现一遍
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, image, mask=None, label=None):
        for t in self.transforms:
            image, mask, label = t(image, mask, label)
        return {'image':image, 'mask':mask, 'label':label}
    pass

if __name__=="__main__":
    # "1.创建一个 dataset 对象"
    train_dataset = SegmentDataset(MyLinesGetter(FilePath="data_flowers/train.txt", dtype="seg").getter(), 
                                   transforms=Compose([Resize((256,256)), ToTensor(),]))

    # "2.创建一个 dataloader 对象"
    train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

    # "3.获取数据集的 mini_batch"
    for (images, masks) in train_data_loader:
        plt.figure(figsize=(20,20))
        plt.imshow(np.hstack(images.permute(0,2,3,1)))
        plt.show()
        plt.figure(figsize=(20,20))
        plt.imshow(np.hstack(masks))
        plt.show()
        break

参考链接

植物素材库
代码高亮
Pytorch dataset&dataloader
图像语义分割实践

上一篇下一篇

猜你喜欢

热点阅读