Multi-GPU training in PyTorch

2021-03-08  Birdy潇

1. Use DDP

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as ddp
dist.init_process_group(backend='nccl')
local_rank = dist.get_rank() # global rank; equals the local rank on a single node
torch.cuda.set_device(local_rank)
device = torch.device('cuda',local_rank) 
# move model to GPU
model = mymodel()
model = model.to(device) # move model to device first
model = ddp(model, device_ids = [local_rank]) # model with ddp, used for ddp training
model_without_ddp = model.module # model without ddp, used for save model and its weights

# wrap up dataloader 
dataset = mydataset()
data_loader = DataLoader(dataset, batch_size=16,
                         sampler=torch.utils.data.distributed.DistributedSampler(dataset),
                         collate_fn=collate_fn, num_workers=4, pin_memory=True)
for input, target in data_loader:
    input = input.to(device)
    target = target.to(device)
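The script above is meant to be launched with one process per GPU, e.g. with torchrun, and when shuffling you should call set_epoch on the DistributedSampler at the start of every epoch so the shards are reshuffled differently. A minimal sketch of the outer training loop, assuming num_epochs, criterion and optimizer are defined elsewhere:

# launch with one process per GPU on a single node, e.g.:
#   torchrun --nproc_per_node=4 train.py
for epoch in range(num_epochs):
    data_loader.sampler.set_epoch(epoch)  # reshuffle shards differently each epoch
    for input, target in data_loader:
        input = input.to(device)
        target = target.to(device)
        loss = criterion(model(input), target)
        optimizer.zero_grad()
        loss.backward()   # DDP all-reduces gradients across processes here
        optimizer.step()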

Be careful when saving the model: save only on the main process (rank 0), and save the model without the DDP wrapper.

# save on the master process only
if dist.get_rank() == 0:
    torch.save({'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict()},
               'checkpoint.pth')  # file name is just an example
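When resuming, every process can load the same checkpoint; passing map_location=device keeps each process from deserializing the tensors onto GPU 0 first. A short sketch, assuming the checkpoint.pth file name used above:

checkpoint = torch.load('checkpoint.pth', map_location=device)
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])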

2. Use prefetch to pre-load data

Install the prefetch_generator package if you don't have it: pip install prefetch_generator

Write a new DataLoader subclass that yields batches from a background thread:

from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super(DataLoaderX, self).__iter__())

Then replace torch's DataLoader with the DataLoaderX class.
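For example, swapping it into the DDP dataloader above only changes the class name; the arguments (including the placeholder dataset and collate_fn from before) stay the same:

data_loader = DataLoaderX(dataset, batch_size=16,
                          sampler=torch.utils.data.distributed.DistributedSampler(dataset),
                          collate_fn=collate_fn, num_workers=4, pin_memory=True)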
