Pytorch 数据加载器: Dataset 和 DataLoa

2019-07-19  本文已影响0人  0过把火0

为什么要用?

习惯于自己实现业务逻辑的每一步,以至于没有意识去寻找框架本身自有的数据预处理方法,Pytorch的Dataset 和 DataLoader便于加载和迭代处理数据,并且可以傻瓜式实现各种常见的数据预处理,以供训练使用。

调包侠

from torch.utils.data.dataset import Dataset, DataLoader
from torchvision import transforms  ##可方便指定各种transformer,直接传入DataLoader

Dataset 和 DataLoader是什么?

Dataset是一个包装类,可对数据进行张量(tensor)的封装,其可作为DataLoader的参数传入,进一步实现基于tensor的数据预处理。

如何处理自己的数据集?

很多教程里分两种情况:数据同在一个文件夹;数据按类别分布在不同文件夹。其实刚开始我是一头雾水,后来总结后发现,两种情况均可用一种方法来处理,即:只要有一份文件,记录图像数据路径及对应的标签即可,如下所示:

record.txt 示例:
    pic_path                          label
./pic_01/aaa.bmp                        1
./pic_22/bbb.bmp                        0
./pic_03/ccc.bmp                        3
./pic_01/ddd.bmp                        1
            ...

其实有了上面的一份数据对照表文件,即可不用管是否在同一文件夹或是不同文件夹的情况,我自己感觉是要方便一些。下面就按照这种方法来介绍如何使用。

第一步:实现MyDataset类

既然是要处理自己的数据集,那么一般情况下还是写一个自己的Dataset类,该类要继承Dataset,并重写 __ init __() 和 __ getitem __() 两个方法。

例如:
class MyDataset(Dataset):
    def __init__(self, record_path, is_train=True):
        ## record_path:记录图片路径及对应label的文件
        self.data = []
        self.is_train = is_train
        with open(record_path) as fp:
            for line in fp.readlines():
                if line == '\n':
                    break
                else:
                    tmp = line.split("\t")
                    ## tmp[0]:某图片的路径,tmp[1]:该图片对应的label
                    self.data.append([tmp[0], tmp[1]])
        # 定义transform,将数据封装为Tensor
        self.transformations = transforms.Compose([transforms.ToTensor()])

    # 获取单条数据
    def __getitem__(self, index):
        img = self.transformations (Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
        label = int(self.data[index][1])
        return img, label

    # 数据集长度
    def __len__(self):
        return len(self.data)

上面是一个简单的MyDataset类,仅依赖记录了图像位置以及相应label的record文件,实现对数据集的读取和Tensor的转换

当然,根据个人对数据预处理的需求不同,该类的实现可进一步完善,例如:

class MyDataset(Dataset):
    def __init__(self, base_path, is_train=True):
        self.data = []
        self.is_train = is_train
        with open(base_path) as fp:
            for line in fp.readlines():
                if line == '\n':
                    break
                else:
                    tmp = line.split("\t")
                    self.data.append([tmp[0], tmp[1]])
        ## transforms.Normalize:对R G B三通道数据做均值方差归一化,因此给出下方三个均值和方差
        normMean = [0.49139968, 0.48215827, 0.44653124]
        normStd = [0.24703233, 0.24348505, 0.26158768]
        normTransform = transforms.Normalize(normMean, normStd)
        ## 可由 transforms.Compose([transformer_01, transformer_02, ...])实现一些数据的处理和增强
        self.trainTransform = transforms.Compose([       ## train训练集处理
            transforms.RandomCrop(32, padding=4),        ## 图像裁剪的transforms
            transforms.RandomHorizontalFlip(p=0.5),      ## 以50%概率水平翻转
            transforms.ToTensor(),                       ## 转为Tensor形式
            normTransform                                ## 进行 R G B数据归一化
        ])
        ## 测试集的transforms数据处理
        self.testTransform = transforms.Compose([  
            transforms.ToTensor(),
            normTransform
        ])

    # 获取单条数据
    def __getitem__(self, index):
        img = self.trainTransform(Image.open(self.data[index][0]).resize((256,256)).convert('RGB'))
        if not self.is_train:
            img = self.testTransform(Image.open(self.data[index][0]).resize((256, 256)).convert('RGB'))
        label = int(self.data[index][1])
        return img, label

    # 数据集长度
    def __len__(self):
        return len(self.data)

或许已经看出来了,所有可能的数据处理或数据增强操作,都可通过transforms来进行调用与封装,是不是一下变得很方便呢!

第二步:将MyDataset装入DataLoader中

MyDataset类中的init方法要求传入记录数据路径及label的文件,因此可如下所示进行操作:

import MyDataset
train_data = MyDataset.MyDataset("./train_record.txt")
test_data = myDataset.MyDataset("./test_record.txt")
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dataset=train_data,batch_size=64,shuffle=True,**kwargs)
testLoader = DataLoader(dataset=test_data,batch_size=64,shuffle=False, **kwargs)

这样,便生成了trainLoader 和testLoader

第三步:在训练中使用DataLoader
for epoch in range(1, args.nEpochs + 1):
     ## 定义好的train方法
     train(args, epoch, model, trainLoader, optimizer)
     ## 定义好的val方法,用于测试或验证
     val(args, epoch, model, testLoader, optimizer)

最后

以上便是使用 Dataset和DataLoader处理自己数据集的通用方法,当然本次仅记录了图片数据的使用方法,后续记录文本数据处理方法。

彩蛋

ooh~~ 那么对于Pytorch自带数据集如果处理呢?
若直接使用 CIFAR10 数据集,可以如下处理:

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
testTransform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainLoader = DataLoader(dset.CIFAR10(root='cifar', train=True, download=True,
                     transform=trainTransform),batch_size=64, shuffle=True, **kwargs)
testLoader = DataLoader(dset.CIFAR10(root='cifar', train=False, download=True,
                     transform=testTransform),batch_size=64, shuffle=False, **kwargs)

其实也就是 torchvision.datasets将这些共用数据集本身就做了 Dataset类的封装,因此直接调用,传入你想要的transforms,再丢给DataLoader即可。

转载注明出处:https://www.jianshu.com/p/b558c538eac2

上一篇 下一篇

猜你喜欢

热点阅读