pytorch笔记01-数据增强

2019-08-18  本文已影响0人  坐下等雨

1、什么是数据增强

数据增强是扩充数据样本规模的一种有效地方法。深度学习是基于大数据的一种方法,我们当前希望数据的规模越大、质量越高越好。模型才能够有着更好的泛化能力,然而实际采集数据的时候,往往很难覆盖掉全部的场景,比如:对于光照条件,在采集图像数据时,我们很难控制光线的比例,因此在训练模型的时候,就需要加入光照变化方面的数据增强。再有一方面就是数据的获取也需要大量的成本,如果能够自动化的生成各种训练数据,就能做到更好的开源节流。

2、数据增强的作用

3、如何进行数据增强

数据增强可以分为两类,一类是离线增强,一类是在线增强。

4、pytorch数据增强操作

pytorch中数据增强的常用方法如下:

torchvision中内置的transforms包含了这些些常用的图像变换,这些变换能够用Compose串联组合起来。

from PIL import Image
from torchvision import transforms as tfs

img = Image.open('./dog.jpg')
print('原图:')
img

原图:


4.1、中心处裁剪PIL图片

class torchvision.transforms.CenterCrop(size)

print('原图像尺寸:{}'.format(img.size))
re_img = tfs.CenterCrop(200)(img)
print('中心裁剪后尺寸:{}'.format(re_img.size))
re_img

原图像尺寸:(658, 411)
中心裁剪后尺寸:(200, 200)


4.2 随机改变图片的亮度、对比度和饱和度

class torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

cj_img = tfs.ColorJitter(0.8, 0.8, 0.5)(img)
cj_img

4.3 图片转换为灰阶

class torchvision.transforms.Grayscale(num_output_channels=1))

gc_img = tfs.Grayscale(1)(img)
gc_img

4.4 图像的各条边缘进行扩展

class torchvision.transforms.Pad(padding, fill=0, padding_mode='constant')

# 用常数0填充
con_img = tfs.Pad(50, fill=0, padding_mode='constant')(img)
con_img
# 用图像边缘值填充
edge_img = tfs.Pad(50, fill=0, padding_mode='edge')(img)
edge_img
# 以边缘为对称轴进行轴对称填充
ref_img = tfs.Pad(50, fill=0, padding_mode='reflect')(img)
ref_img

4.5 图片在随机位置处进行裁剪

class torchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)

rc_img = tfs.RandomCrop(200)(img)
rc_img

4.6 以给定的概率随机水平翻折PIL图片

class torchvision.transforms.RandomHorizontalFlip(p=0.5)

rh_img = tfs.RandomHorizontalFlip(1)(img)
rh_img

4.7 以给定的概率随机垂直翻折PIL图片

class torchvision.transforms.RandomVerticalFlip(p=0.5)

rv_img = tfs.RandomVerticalFlip(1)(img)
rv_img

4.8 以指定的角度选装图片

class torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None)

rr_img = tfs.RandomRotation(45)(img)
rr_img

以上都是对图像做单次变换,torchvision提供torchvision.transforms.Compose()函数,可以将以上图像方法联合起来使用,比如先做随机翻转,然后随机截取,再做对比度增强等。
import matplotlib.pyplot as plt
%matplotlib inline

aug_img = tfs.Compose([
    tfs.Resize(200),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(120),
tfs.RandomVerticalFlip(),
    tfs.ColorJitter(0.5, 0.5, 0.5)
])

_, figs = plt.subplots(3, 3, figsize=(10, 10))
for i in range(3):
    for j in range(3):
        figs[i][j].imshow(aug_img(img))
        figs[i][j].axes.get_xaxis().set_visible(False)
        figs[i][j].axes.get_yaxis().set_visible(False)
上一篇 下一篇

猜你喜欢

热点阅读