看pytorch official tutorials的新收获(

2019-11-14  本文已影响0人  赖子啊
pytorch标志

2019/11/3
今天才正式开始看pytorch的官网教程,把一些基础的操作先搞明白吧,虽然之前跑过简单的demo,总是感觉有些地方不能解释得很好,所以这次记录一下新的收获:

1 torch tensor

pytorch里面最基础格式的量就是torch tensor了,但这么基础的量在初始化的时候,还是可以指定很多参数的,例如torch.tensor(data, dtype, device, requires_grad),一个tensor还有.grad和.grad_fn属性。

2 in-place operation

原文:

Any operation that mutates a tensor in-place is post-fixed with an_. For example: x.copy_(y), x.t_(), will change x.

3 numpy vs tensor

原文:numpy的数据和torch tensor格式是可以相互转化的,用tensor.numpy()和torch.from_numpy()就可以互相转化,不过得注意下面一点:

The Torch Tensor and NumPy array will share their underlying memory locations (if the Torch Tensor is on CPU), and changing one will change the other.

4 关于torch.nn建立网络

原文:以前没有留意这问题,不过一般也都符合这个输入要求,就是必须是有batch的那个维度

torch.nn only supports mini-batches. The entire torch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.
If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.

虽然以前知道可以print(net)来查看网络的结构,但是不知道还可以用net.conv1来对第一个卷积核进行操作(当然前提是你在init函数里面已经定义了conv1)

5 torch.utils.data

Dataset

torch.utils.data.Dataset,附上原文的一段介绍(因为感觉翻译没有原来的味道):

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite__len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

主要就是有自己训练集的时候,可以构建一个类来继承Dataset类,并且重写__getitem__()方法和__len__()方法,这样也就可以用后面的的Dataloader来加载数据集了。

DataLoader

torch.utils.data.DataLoader是pytorch里面核心的数据加载的模块,他的接口是这样的,看到可选的参数很多,常用的有dataset, batch_size, shuffle, num_workers

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 
                                                           sampler=None, batch_sampler=None, num_workers=0, 
                                                           collate_fn=None, pin_memory=False, drop_last=False, 
                                                           timeout=0, worker_init_fn=None, multiprocessing_context=None)

6 torchvision

datasets

torchvision.datasets里面有很多的可以加载的数据集,如MNIST、Fashion-MNIST、KMNIST、EMNIST、QMNIST、FakeData、COCO(Captions\Detection)、LSUN、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD、USPS、Kinetics-400、HMDB51、UCF101。这些都继承了torch.utils.data.Dataset这个类,所以这些数据集都可以用torch.utils.data.DataLoader的多线程来进行快速的加载(如果我们自己构建自己的dataset,去重写lengetitem方法,也可以调用torch.utils.data.DataLoader来对数据进行加载),而且他们的API接口都很像,差不多都有下面几个参数(以ImageNet为例,不用解释,一看能猜出来):

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)

当然里面也有一个通用的接口,可以让你自己构建数据集:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
只要你按照下面的文件目录结构存放自己的图像就行
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

这里面还有一些是比较常用的数量,一般情况下我们会构建两个dataset,一个是训练的,一个是测试的,如trainset,valset;与之对应就有两个DataLoader,一个是trainloader,一个是valloader:

for data in trainloader:
    batch_img, batch_target = data
# 或者有的地方也这样写
dataiter = iter(trainloader)
images, labels = dataiter.next()

transforms

torchvision.transforms主要用于对图像进行变换,也就是图像增强data augmentation。

最后可以用 torchvision.transforms.Compose(transforms)把想要对图像做的变换都集中起来,transforms是各种之前特定变换操作的列表

small summary

所以结合6中两个就可以差不多这样写

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

Common import

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

Note

The output of torchvision datasets are PILImage images of range [0, 1].

More

See here for more details on saving PyTorch models.
If you want to see even more MASSIVE speedup using all of your GPUs, please check out Optional: Data Parallelism.

上一篇 下一篇

猜你喜欢

热点阅读