[PyTorch]多卡运行(transformer-xl)
2019-03-31 本文已影响0人
VanJordan
原理
- 多
GPU
运行的接口是torch.nn.DataParallel(module, device_ids)
其中module
参数是所要执行的模型,而device_ids
则是指定并行的GPU id
列表。 - 而其并行处理机制是,首先将模型加载到主
GPU
上,然后再将模型复制到各个指定的从GPU
中,然后将输入数据按batch
维度进行划分,具体来说就是每个GPU
分配到的数据batch
数量是总输入数据的batch
除以指定GPU
个数。每个GPU
将针对各自的输入数据独立进行forward
计算,最后将各个GPU
的loss
进行求和,再用反向传播更新单个GPU
上的模型参数,再将更新后的模型参数复制到剩余指定的GPU
中,这样就完成了一次迭代计算。所以该接口还要求输入数据的batch
数量要不小于所指定的GPU
数量。 -
DataParallel
自动地分割输入数据,同时将他们发送到每个GPU
的模型中. 当模型处理完成后,DataParallel
会将各个设备中的处理结果收集和合并,再返回给用户。
示意图
需要注意
- 主
GPU
默认情况下是0
号GPU
,也可以通过torch.cuda.set_device(id)
来手动更改默认GPU
。 - 提供的多
GPU
并行列表中需要包含有主GPU
。 - 但是,
DataParallel
有一个问题:GPU
使用不均衡。在一些设置下,主GPU
会比其他GPU
使用率高得多。
例子
- 构建多
GPU
的DataParallel
if args.multi_gpu:
model = model.to(device)
if args.gpu0_bsz >= 0:
para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
model, dim=1).to(device)
else:
para_model = nn.DataParallel(model, dim=1).to(device)
else:
para_model = model.to(device)
-
在正向传播的时候使用para_model,其他的时候,比如使用模型的参数可以直接调用
model.parameters()
-
均衡的
DataParallel
class BalancedDataParallel(DataParallel):
def __init__(self, gpu0_bsz, *args, **kwargs):
self.gpu0_bsz = gpu0_bsz
super().__init__(*args, **kwargs)
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
if self.gpu0_bsz == 0:
device_ids = self.device_ids[1:]
else:
device_ids = self.device_ids
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids)
if self.gpu0_bsz == 0:
replicas = replicas[1:]
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
return self.gather(outputs, self.output_device)
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, device_ids)
def scatter(self, inputs, kwargs, device_ids):
bsz = inputs[0].size(self.dim)
num_dev = len(self.device_ids)
gpu0_bsz = self.gpu0_bsz
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
if gpu0_bsz < bsz_unit:
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
delta = bsz - sum(chunk_sizes)
for i in range(delta):
chunk_sizes[i + 1] += 1
if gpu0_bsz == 0:
chunk_sizes = chunk_sizes[1:]
else:
return super().scatter(inputs, kwargs, device_ids)
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)