(6)自定义数据集

2018-12-04  本文已影响0人  顽皮的石头7788121

    PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。准备数据的代码一般为data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)。datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

    如何自定义Datasets

下面是一个自定义Datasets的框架

class CustomDataset(data.Dataset):#需要继承data.Dataset

    def __init__(self):

        # TODO

        # 1. Initialize file path or list of file names.

        pass

    def __getitem__(self, index):

        # TODO

        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).

        # 2. Preprocess the data (e.g. torchvision.Transform).

        # 3. Return a data pair (e.g. image and label).

        #这里需要注意的是,第一步:read one data,是一个data

        pass

    def __len__(self):

        # You should change 0 to the total size of your dataset.

        return 0

下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

class MNIST(data.Dataset):

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):

        self.root = root

        self.transform = transform

        self.target_transform = target_transform

        self.train = train  # training set or test set

        if download:

            self.download()

        if not self._check_exists():

            raise RuntimeError('Dataset not found.' +

                              ' You can use download=True to download it')

        if self.train:

            self.train_data, self.train_labels = torch.load(

                os.path.join(root, self.processed_folder, self.training_file))

        else:

            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):

        if self.train:

            img, target = self.train_data[index], self.train_labels[index]

        else:

            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets

        # to return a PIL Image

        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:

            img = self.transform(img)

        if self.target_transform is not None:

            target = self.target_transform(target)

        return img, target

    def __len__(self):

        if self.train:

            return 60000

        else:

            return 10000


上一篇下一篇

猜你喜欢

热点阅读