TensorFlow multi-GPU

2019-11-13  乘瓠散人

Key points for implementing single-machine, multi-GPU data parallelism:

  1. Feed each GPU its own mini-batch of `bs` samples: fetch one combined batch of `bs * num_gpus` samples and split it evenly across the GPUs.

```python
imgs, labels = next_batch(batch_size=cfgs.bs * num_gpus)
imgs_splits = tf.split(imgs, num_gpus)
labels_splits = tf.split(labels, num_gpus)
```
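Here `cfgs.bs` is the per-GPU batch size, and `next_batch` is the author's data helper, which the post does not show. A minimal hypothetical sketch of such a helper using `tf.data`, assuming in-memory arrays `train_imgs` and `train_labels`:

```python
import tensorflow as tf

def next_batch(batch_size):
    # hypothetical helper, not from the original post
    dataset = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels))
    # drop_remainder=True guarantees exactly batch_size samples per batch,
    # so tf.split(imgs, num_gpus) can divide the batch dimension evenly
    dataset = dataset.shuffle(10000).repeat().batch(batch_size, drop_remainder=True)
    return dataset.make_one_shot_iterator().get_next()
```

Whatever the helper looks like, `tf.split` requires the batch dimension to be exactly divisible by `num_gpus`, which is why the combined batch size is defined as `bs * num_gpus`.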
  2. Place a replica of the model on each GPU and compute each tower's loss and gradients, reusing variables so that all towers share a single set of parameters (a sketch of `train_model` follows this block):

```python
gpus = ['0', '2']
tower_grads = []
tower_loss = []
tower_outs = []
opt = tf.train.AdamOptimizer(lr)
with tf.variable_scope(tf.get_variable_scope()):
    for i, d in enumerate(gpus):
        with tf.device('/gpu:%s' % d):
            with tf.name_scope('tower_%s' % d):
                # compute the loss for one tower of the model;
                # train_model builds the full model but shares variables across all towers
                loss, out = train_model(imgs_splits[i], labels_splits[i])
                # reuse variables for the next tower
                tf.get_variable_scope().reuse_variables()
                # calculate the gradients for the batch of data on this tower
                grads = opt.compute_gradients(loss)
                tower_grads.append(grads)

                tower_loss.append(loss)
                tower_outs.append(out)

# compute mean gradients
mean_grads = average_gradients(tower_grads)
# update the gradients to adjust the shared variables
train_op = opt.apply_gradients(mean_grads)
totloss = tf.reduce_mean(tower_loss)
# concatenate the per-tower outputs into one tensor with batch size bs * num_gpus
last_outs = tf.concat(tower_outs, axis=0)
```
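`train_model` is also not shown in the post. The key requirement is that it creates its variables through `tf.get_variable` (for example via `tf.layers`), so that the `reuse_variables()` call above lets every later tower reuse the weights created by the first. A minimal hypothetical sketch:

```python
def train_model(imgs, labels):
    # tf.layers creates its weights via tf.get_variable, so after
    # tf.get_variable_scope().reuse_variables() each later tower
    # shares the parameters built by the first tower
    net = tf.layers.conv2d(imgs, 32, 3, activation=tf.nn.relu, name='conv1')
    net = tf.layers.flatten(net)
    logits = tf.layers.dense(net, 10, name='fc')
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
    return loss, logits
```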
  3. A function that averages the gradients computed on each tower:

```python
def average_gradients(tower_grads):
    """Average the (gradient, variable) pairs produced by each tower."""
    average_grads = []
    # zip(*tower_grads) regroups the per-tower lists so that each
    # grad_and_vars holds one variable's gradients from all towers
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            # add a leading 'tower' dimension so the gradients can be stacked
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)
        # stack along the tower dimension, then average over it
        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)
        # the variables are shared across towers, so tower 0's handle suffices
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads
```
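For intuition: with two towers and variables `w` and `b`, `tower_grads` is `[[(g0_w, w), (g0_b, b)], [(g1_w, w), (g1_b, b)]]`. `zip(*tower_grads)` then yields one group per variable, e.g. `((g0_w, w), (g1_w, w))`, whose gradients are stacked and averaged while the shared variable is taken from tower 0.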
  4. Let GPU memory usage grow on demand instead of reserving it all upfront:

```python
# fall back to a supported device when an op cannot run on the requested GPU
tfconfig = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
tfconfig.gpu_options.allow_growth = True  # allocate GPU memory incrementally
totsess = tf.Session(config=tfconfig, graph=tf.get_default_graph())
```
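A minimal driver loop tying these pieces together might look like the following sketch (`num_steps` is an assumed hyperparameter, not from the original post):

```python
totsess.run(tf.global_variables_initializer())
for step in range(num_steps):  # num_steps: assumed number of training steps
    _, loss_val = totsess.run([train_op, totloss])
    if step % 100 == 0:
        print('step %d, mean tower loss: %.4f' % (step, loss_val))
```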

