pytorch学习笔记-dataloader输出不同尺寸的输入图

2020-08-18  本文已影响0人  升不上三段的大鱼

pytorch可以自己定义 Dataset类, 然后用dataloader 函数来获取输入以及对应标签。下面是个简单的例子:

from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader

class TrainDataset(Dataset):
    def __init__(self, root_dir, csv_file, transform):

        self.root_dir = root_dir
        self.labels = pd.read_csv(csv_file)
        self.transform = transform
        
    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, index):
        img_name = os.path.join(self.root_dir,
                                self.labels.iloc[index, 0])
        image = Image.open(img_name+'.jpg')
        label = self.labels.iloc[index,1:].astype(int).to_numpy()
        label = np.argmax(label)

        if self.transform:
            image = self.transform(image)

        return image, label

 dataset = TrainDataset(
        root_dir='./data/Input',
        csv_file=csv_file,
        transform=transforms.Compose([
           transforms.Resize(224, 224),
           transforms.HorizontalFlip(p=0.5),
           transforms.VerticalFlip(p=0.5),
           transforms.Rotate(limit=(-90,90)),
           transforms.RandomBrightnessContrast(),
        ])
    )

data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True,num_workers=1)

for inputs, labels in data_loader:
     img = torchvision.utils.make_grid(inputs[0])
     img_nm = img.numpy()
     img_trans = np.transpose(img_nm, (1, 2, 0))
     plt.imshow(img_trans)
     plt.show()

这样就可以使用自己定义的数据集了。

但是如果想要让数据集保持自己原来的尺寸,也就是说如果不用 transforms.Resize(224, 224), 把图片都缩放到224,而是保持他们原来各自不同的尺寸,需要怎么做呢?

只需要加一个自定义的collate_fn函数就可以了。在默认情况下,pytorch将图片叠在一起,成为一个NCH*W的张量,因此每个batch里的每个图像必须是相同的尺寸。所以如果想要接受不同尺寸的输入图片,我们就要自己定义collate_fn。
对于图像分类,collate_fn的输入大小是batch_size 大小的list, list里每个元素是一个元组,元组里第一个是图片,第二个是标签。对于不同大小的输入图片,我们可以使用list来储存。具体实现如下(Dataset类里面去掉resize):

def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True, collate_fn =my_collate)
trainiter = iter(data_loader)
imgs, labels = trainiter.next()

然后就可以得到保留了原尺寸的图片了。
不过要注意这里得到的 imgs是一个list,用的时候注意数据类型。

上一篇 下一篇

猜你喜欢

热点阅读