torchvision包
2018-07-07 本文已影响243人
FantDing
title: torchvision包
date: 2018-07-09 10:13:47
tags:
- torchvision
categories:
- pytorch
- torchvision
1. 子模块/包
-
dataset包
其下的类都继承自
torch.utils.data.Dataset
-
utils
这里的
torchvison.utils
只处理图像,而torch.utils.data
有一个重要的class:DataLoader
-
transforms
进行图像的变换,用于增广数据集
-
models
直接使用经典的网络结构(也可加载预训练参数)
2. utils
2.1. torchvision.utils.make_grid(tensors)
将tensors合并成tensor。tensor.numpy()为BRG模式的图片。官网API
- 参数
- tensors: [ BATCH×C×H×W ]
- 返回
- tensor: [C×H×W]
- 注意
- 不论传入的图片们(tensors)的通道数
C
是1 or 3
, 返回tensor的通道数都是3
- 不论传入的图片们(tensors)的通道数
3. transforms
3.1. transforms.Compose[list]
核心:对于传入的PIL image每次transform之后,将结果传入到下次transform操作中
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
3.1.1. PIL图像
- 输入需是PIL图像,其size是
[Width, Height]
。因此transforms.Compose(list)
是有顺序的transforms.Compose([ transforms.RandomResizedCrop(224), # 输入PIL image transforms.RandomHorizontalFlip(), # 输入PIL image transforms.ToTensor(), # 放在最后,将PIL image(size是W,H)转换成Tensor(size是C,H,W) transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])
- PIL图像的转换
import torchvision.transforms as transforms import torch import numpy as np from PIL import Image if __name__=="__main__": # 1. 读取成PIL image img=Image.open("./test.jpeg") print(img.size) # (3840,2160)=>(width, height) # 2. PIL转换成ndarray np_img=np.array(img) print(np_img.shape) # (2160, 3840, 3) => (height,width,c) # 3. PIL转换成tensor totensor=transforms.ToTensor() tensor_img=totensor(img) print(tensor_img.size()) # torch.Size([3, 2160, 3840]) # 4. tensor转换成PIL topil=transforms.ToPILImage() pil=topil(tensor_img) print(pil.size) # (3840, 2160)
- 之所以可以将PIL转换成
ndarray
,是因为PIL Image
是array_like
的,具体见stackoverflow
- 之所以可以将PIL转换成
3.2. transforms.RandomResizedCrop(targetsize)
Crop the given PIL Image to random size and aspect ratio. Then, this crop is finally resized to given size.[选定随机的面积 and 这个面积的纵横比,来裁剪PIL图像。最后将裁剪好的图像resize到高、宽都为targetsize]
3.2.1.核心
def __call__(self, img):
# (i,j)左上角坐标
i, j, h, w = self.get_params(img, self.scale, self.ratio)
# 先对img进行crop,再通过self.interpolation插值成self.size
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
3.3. Resize()
- 参数para
- 二维sequence like:则缩放到给定的size
- int: 在等比例缩放的前提下,将最小边缩放到para。 i.e, if height > width, then image will be rescaled to (size * height / width, size)
3.4. ToTensor()
将
[0, 255]
的PIL image或者ndarray(H * W * C)转换成[0.0, 1.0]
的Tensor
4. datasets
from torchvision import datasets
构造自己的/已有的数据集
4.1. 公共点
-
datasets
模块下所有的类(ImageFolder
,mnist
等)都继承自torch.util.data.Dataset
- 因此也常常通过
torch.utils.data.DataLoader
辅助加载数据
- 因此也常常通过
- 构造函数,都可以传入
transform
4.2. ImageFolder
- 参数
- root: Root directory path。组织形式如下:
# root_dir/class_name/*.[png|jpg...] root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
- transform: 对于输入图片的变化,可用于数据增广
- 所有的
datasets
下面的类,都接受transform
- 所有的
- root: Root directory path。组织形式如下:
- 属性
- classes (list): List of the class names.
- class_to_idx (dict): Dict with items (class_name, class_index).
- samples (list): List of (sample path, class_index) tuples
4.3. 辅助类torch.utils.data.DataLoader
可以将
Dataset
传入