图像分类

pytorch 数据加载

2021-05-24  本文已影响0人  1037号森林里一段干木头

简介:pytorch提供的加载数据集的两个工具包Dataset和DataLoader,对Dataset进行简单的改造就可以加载自己的数据了。

1.Dataset 类

要使用pytorch提供的Dataset类,需要重写两个主要的方法

也就是基本所有的自定义数据加载类看起来都会是这样的

from torch.utils.data import Dataset
class myDataset(Dataset):
    def __init__(self,):
        pass
    
    def __len__(self):
        pass
    
    def __getitem__(self, idx):
        pass

2. 自定义Dataset示例:

以kaggle竞赛中的猫狗分类数据集为例,在训练集中包含12500张猫的图片和12500张狗的图片,文件名为: class.index.jpg

image.png
重写Dataset类的init,len,getitem方法
class CatsAndDogs(Dataset):
    def __init__(self, root,transforms=None,size=(224,224)):
        #初始化
        self.images = [os.path.join(root,item) for item in os.listdir(root)]
        self.transforms = transforms
        self.size = size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        #这里需要resize是因为用Dataloader加载的同一个batch里面的图片大小需要一样
        image = np.array(Image.open(self.images[idx]).resize(self.size))
        #the format of the path :"K:\\imageData\\dogAndCat\\train\\dog.9983.jpg"
        label = self.images[idx].split("\\")[-1].split(".")[0]
        return image,label

3.DataLoader

重写完Dataset之后使用DataLoader获得batch数据。

trainLoader = DataLoader(mydataset,batch_size=32,num_workers=2,shuffle=True)
for images,labels in trainLoader:
    pass

DataLoader原型
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

4.完整示例

from torch.utils.data import Dataset,DataLoader
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class CatsAndDogs(Dataset):
    def __init__(self, root,transforms=None,size=(224,224)):
        #初始化

        self.images = [os.path.join(root,item) for item in os.listdir(root)]
        self.transforms = transforms
        self.size = size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        #这里需要resize是因为用Dataloader加载的同一个batch里面的图片大小需要一样
        image = np.array(Image.open(self.images[idx]).resize(self.size))
        #the format of the path :"K:\\imageData\\dogAndCat\\train\\dog.9983.jpg"
        label = self.images[idx].split("\\")[-1].split(".")[0]
        return image,label

if __name__ == "__main__":
    mydataset = CatsAndDogs(r"K:\imageData\dogAndCat\train")
    dataloader = DataLoader(mydataset,batch_size=8,shuffle=True)
    num = 0

    for imgs,labels in dataloader:
        print(labels)
        print(imgs.size())
        num += 1
        if num > 2:
            break

    for i in range(8):
        ax = plt.subplot(3,3,i+1)
        ax.imshow(imgs[i])
        ax.set_title(labels[i])
        ax.axis("off")
    # plt.imshow(imgs[1])
    plt.show()
('dog', 'dog', 'dog', 'dog', 'cat', 'dog', 'dog', 'cat')
torch.Size([8, 224, 224, 3])
('dog', 'cat', 'dog', 'dog', 'dog', 'cat', 'dog', 'dog')
torch.Size([8, 224, 224, 3])
('dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog')
torch.Size([8, 224, 224, 3])
image.png
上一篇 下一篇

猜你喜欢

热点阅读