Pytorch基础(一):Data Loading and Pr

2019-04-03  本文已影响0人  csuhan

解决机器学习问题中很多的工作都是在处理数据。Pytorch提供许多工具,是的数据操作更加简便有用,使代码更加的易读。

Dataset class

Dataset位于torch.units.data.Dataset,是一个抽象类,用于代表数据集。我们可以继承它,然后重写以下两个方法:

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

Transforms

addon:torchvision包含三个模块:dataset包含常见数据集,models包含常见模型,tranforms用于对影像进行变换。
Tranforms位于torchvision.transforms,即对数据集进行变换,如影像数据集的裁减,缩放等。

  1. transforms.Compose将若干transforms操作集合起来,如
transforms.Compose([transforms.CenterCrop(10),transforms.ToTensor(),])

CenterCropToTensor结合起来。
小示例如:

import torchvision.transforms as transforms

In [6]: trans1 = transforms.Compose([transforms.ToTensor(),])

In [7]: img = cv2.imread('elephant.png')

In [8]: img1 = trans1(img)

In [9]: type(img1)
Out[9]: torch.Tensor

In [10]: type(img)
Out[10]: numpy.ndarray

In [11]: trans2 = transforms.Compose([transforms.ToPILImage(),])

In [12]: img2 = trans2(img)

In [13]: type(img2)
Out[13]: PIL.Image.Image
  1. Transforms on PIL Image。包含了诸如图片裁剪,变换等操作。
  2. Ttransforms on torch.*Tensor。线性变换,正则化等操作。
  3. Conversion Transforms。包含了向Tensor或者向PIL Image转换的transforms
  4. Generic Transforms。更加通用的转换方法。

DataLoader

即将标准的Dataset装载进DataLoader,从而实现mini batch,shuffle等便捷操作。【时间有限,下次更】

上一篇 下一篇

猜你喜欢

热点阅读