PyTorch

Pytorch中transform的常见用法

2020-03-02  本文已影响0人  geekboys

transforms模块详解

transforms是torchvision中的一个重要模块,它是Pytorch的图像预处理包,包含了很多种对图像数据进行变换的函数,这些都是我们加载训练数据步骤中必不可少的。比较常见的是下面的这部分代码:

data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

Compose方法是将多种变换组合在一起。上述对data_transforms进行了四种变换,前两个是对PILImage进行的,分别对其进行随机大小和随机宽高比的裁剪,之后resize到指定大小224,以及对原始图像进行随机的水平翻转;
第三个transforms.ToTensor()将PILImage的转变为torch.FloatTensor的数据形式;最后一个Normalize则是对tensor进行的,不要问这些数值是怎么来的它们都是从ImageNet训练模型中总结出来的参数。下面需要着重强调一点是多种组合变换有一定的先后顺序,处理PILImage的变换方法(大多数方法)都需要放在ToTensor方法之前,而处理tensor的方法如上面的Normalize方法则要放在ToTensor方法之后。

transforms中的一些函数

output[channel]=(input[channel]-mean[channel])/std[channel]

例如:原来的tensor是三个维度,值在[0,1]之间,经过变换之后得到[-1,1]
计算如下:

((0,1)-0.5)/0.5=(-1,1)

transforms针对PILImage的操作还有很多

上一篇 下一篇

猜你喜欢

热点阅读