PyTorch-混合精度训练(Mixed-Precisioin

2023-06-30  本文已影响0人  侠之大者_7d3f

混合精度训练介绍

Mixed-Precision Training是指在深度学习AI模型训练过程中不同的层Layer采用不同的数据精度进行训练, 最终使得训练过程中的资源消耗(GPU显存,GPU 算力)降低, 同时保证训练可收敛,模型精度与高精度FP32的结果接近。


CNN ResNet 混合精度训练

  1. 导入torch.cuda.amp package
    由于CNN训练要求大量算力, 因此一般混合精度需要使用 NVIDIA Automatic Mixed Precision (AMP)包, NVIDIA的AMP以及集成到了Pyorch, 因此直接调用 torch.cuda.amp APIs.

混合精度主要用到 Loss-Scaling (损失缩放) + Auto-cast (自动精度选择/转换)

# Mixed-Precision Training
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast

# 实例化一个GradeScaler对象
scaler = GradScaler()
  1. 对Training Loop进行修改, 修改2个地方
 for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with autocast(enabled=args.mixed_precision, dtype=torch.float16):
                output = model(images)
                loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
         ...

为了便于观察训练过程, 在代码中添加了Pytorch Profiler进行可视化:
为了说明情况, 可以只跑少数的几个batch即可

image.png image.png image.png

对应的CUDA kernel函数, FP16类型的


image.png

采用BFLOAT16进行混合精度

方法很简单,在autocast的dtype设置为torch.bfloat16。 除此之外,需要采用支持BFLOAT16类型的计算设备(TPU, >=NVIDIA Ampere/Volta 架构的GPU, 比如NVIDIA V100, A100, RTX 30/40系列)

with autocast(enabled=args.mixed_precision, dtype=torch.bfloat16):
                output = model(images)
                loss = criterion(output, target)
image.png
image.png
上一篇 下一篇

猜你喜欢

热点阅读