pytorch我爱编程

Pytorch学习(3): 常用工具模块简介

2018-04-04  本文已影响64人  月牙眼的楼下小黑

作 者: 月牙眼的楼下小黑
联 系zlf111@mail.ustc.edu.cn
声 明: 欢迎转载本文中的图片或文字,请说明出处


参考资料:

[1].PyTorch常用工具模块

1 数据处理

import torch 
from torch.utils import data
import os 
from PIL import Image
import numpy as np 
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()

1.1 数据加载

Pytorch 中, 数据加载可通过自定义一个继承Dataset类的数据集对象, 并实现两个方法:

class DogCat(data.Dataset):
    def __init__(self, root):
        imgs = os.listdir(root)                               # root:图片所在文件夹路径
        self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        if 'dog' in img_path.split('/')[-1]:
            label =1
        else: 
            label = 0
        pil_img= Image.open(img_path)            # 利用 python 图像处理标准库的 open 方法打开图片
        array = np.asarray(pil_img)              # 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]
        #array = np.transpose(array, (2, 0, 1))  # 将 channel last 形式转化成  channel first 形式:[channel, height, width]
        data = torch.from_numpy(array)           # 将 np.ndarray  转化为 Tensor 形式
        return data, label
    
    def __len__(self):
        return len(self.imgs)

补充: 对于三维矩阵的转置, 如 a.transpose(2,0,1), 意思是原矩阵a(aix 0, aix 1, aix 2) 处的值,现在成为了转置后矩阵 (aix 2, aix 0 , aix 1)处的值。

in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs')
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img)                    # 若为 channel last 形式的 tensor, 可用 matplotlib 中 imshow() 方法
print(label, img.size(), img.float().mean())
out:
1 torch.Size([500, 282, 3]) 169.23073522458628
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img)                    # 若为 channel first 形式的 tensor, 可用 transforms 中 的 ToPILImage() 方法
print(label, img.size(), img.float().mean())

在前面文章中提到过:ToPILImage 可以将

转化成PIL.Image,值不变,方便可视化。注意到 它只能转变channel first形式的Tensor 。而在上面的__getitem__中,array = np.asarray(pil_img)PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]。 所以如果我们要使用 ToPILImage 方法显示图片,在将 PIL.image 转化为 np. ndarray 形式后,还需要利用转置方法将 channel last 形式改成 channel first 形式:[channel, height, width]

1.2 数据预处理

torchvision.transforms模块提供了对 PILImage对象和Tensor对象的常用操作。

PILImage的操作包括:

Tensor的操作包括:

in:
trans = transforms.Resize((100,100))
image = Image.open('./dog.1.jpg')
print(image.size)
image = trans(image)
print(image.size)
out:
(327, 499)
(100, 100)

如果要对图片进行多个操作, 可通过Compose方法将这些操作拼接起来。

in:
transform = transforms.Compose([
    transforms.Resize(224),           # 缩放图片,保持长宽比不变,最短边为224像素
    transforms.CenterCrop(224),   # 从图片中间切出 224x224 的图片
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [1, 1, 1])
])

class DogCat(data.Dataset):
    def __init__(self, root, transforms = None):
        imgs = os.listdir(root)                                                # root:图片所在文件夹路径
        self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
        self.transforms = transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        if 'dog' in img_path.split('/')[-1]:
            label =1
        else: 
            label = 0
        data = Image.open(img_path)           # 利用 python 图像处理标准库的 open 方法打开图片
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs', transforms = transform)
img,label= dataset[0]
print(img.size())
show(img)
out:
torch.Size([3, 224, 224])

1.3 ImageFolder

torchvision预先实现了常用的DataSet,如CIFAR-10, 可通过 torchvision.datasets.CIFAR10来调用。这里介绍一个经常使用的 DataSet——ImageFolder. ImageFolder 假设所有文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名, ImageFolder会根据文件夹名顺序自动生成 label, 可以通过 class_to_idx查看 label 和 文件夹名的映射关系。其构造函数如下:

ImageFolder(root, transform = None, target_transform = None, Loader = default_loader)
in:
from torchvision.datasets import ImageFolder
dataset = ImageFolder('./myDataSet/dog_cat_data/train')
in:
dataset.class_to_idx
out:
{'cats': 0, 'dogs': 1}
in:
# 此时还没有任何 transform, 返回的是 PILImage 对象
# 第一维指示第几张图片,第二维为 1 返回 label, 为 0 返回 图片数据
print(dataset[0][1]) 
dataset[0][0]
out:
0

1.4 DataLoader

调用DataSet中的__getitem__只返回一个样本,而我们需要batch wise trainingPytorch提供了 DataLoader帮助我们实现这些功能。其构造函数如下:

DataLoader(dataset, 
           batch_size=1, 
           shuffle = False,
           sample =None,
           sampler = None,
           num_workers =0,
           collate_fn = default_collate, 
           pin_memory =False,
           drop_last = False)

2. torchvision

torchvisionPytorch 团队开发的独立于 Pytorch的视觉工具包,通过pip install torchvision安装,主要包含三部分:

3. 可视化工具 Visdom

Visdom可以创造、组织和共享多种数据的可视化,包括数值、图像、文本、视频, 支持 PytorchNumpy。.

Visdom中的两个重要概念:

In:
import visdom
vis = visdom.Visdom(env=u'test1')    # 构建一个客户端对象,创建一个名为' test1' 的 env 
x = torch.arange(1, 30, 0.01)
y = torch.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})   #  win 是 pane 名字,opts 设置 pane 格式,如 title, xlabel,ylabel

vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) 中, win 参数指定 pane 名字, 如果不指定,visdom将自动分配一个新的pane. 如果两次操做指定的win名字一样,新操作将覆盖当前 pane 的内容。如在 上面的 pane中画 y = x 函数,原来的 y = sin(x) 将被覆盖。
In:
y = x
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=x'})

如果不想覆盖原图,可以使用updateTrace方法,如:
y = x + 1
vis.updateTrace(X=x, Y=y, win='sinx', name='this is a new Trace')

未完待续

上一篇下一篇

猜你喜欢

热点阅读