数据增强

2022-09-12  本文已影响0人  小黄不头秃

(一)数据增强(增广)

(1)为什么要做数据增强?

一个原因是可能你的数据集比较小,所以需要对数据进行简单的操作,让数据集增加。第二是有这样的一个真实实例。有一家做智能售货机的公司在公司内部调试好参数训练好模型以后,将售货机拿去展厅进行测试的时候,发现原本准确率非常高的机器忽然识别不出来了。原因是展厅的光源不一样,导致整个的测试数据集就和训练集发生了很大的变化。所以在产品研发的时候适当的通过数据增强技术能够给模型增加鲁棒性。

(2)数据增强方法有什么?

可通过在图片中加入各种不一样的背景噪音,改变图片的颜色和形状。

(二)代码实现

%matplotlib inline
import torch 
import torchvision
from torch import nn 
from d2l import torch as d2l
import matplotlib.image as img
import matplotlib.pyplot as plt

# 打开图片的方法
# image = img.imread('../img/cat1.jpg')
# plt.title("cat.jpg")
# plt.axis("off")
# plt.imshow(image)
# plt.show()

d2l.set_figsize()
img = d2l.Image.open('../img/cat1.jpg')
d2l.plt.imshow(img)
# 参数列表(图片,增强的办法,多少行,多少列,倍数)
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    y = [aug(img) for _ in range(num_rows*num_cols)]
    d2l.show_images(y, num_rows, num_cols, scale=scale)

# 左右翻转图片
apply(img,torchvision.transforms.RandomHorizontalFlip())

# 上下翻转
apply(img,torchvision.transforms.RandomVerticalFlip())

# 随即裁剪
shape_aug = torchvision.transforms.RandomResizedCrop(
    # (输出大小,选择的比例,高宽比)
    size=(200,200),scale=(0.1,1),ratio=(0.5,2))
apply(img,shape_aug)
# 随机更改图片的亮度
apply(img,torchvision.transforms.ColorJitter(
    # (亮度区间,对比度,饱和度,色调)
    brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5
))
# 结合多种数据增强方法
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    shape_aug,
    torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
])

apply(img, augs)
# 如果下载报错的话,自己去网页上下载
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
all_images = torchvision.datasets.CIFAR10(
    train=True,
    root="../data",
    download=False
)

d2l.show_images([all_images[i][0] for i in range(32)],4,8,scale=0.8)
# d2l.show_images([all_images.data[i] for i in range(32)],4,8,scale=0.8)
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()
])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
def load_cifar10(is_train,augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(
        train=is_train,
        root="../data",
        download=False,
        transform=augs,
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    return data_loader
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需(稍后讨论)
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum
#@save
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
train_with_data_aug(train_augs, test_augs, net)
上一篇下一篇

猜你喜欢

热点阅读