多卡训练的数据并行

2018-08-02  本文已影响0人  飞奔的卤蛋

最近在做多卡的实验,当然是使用最新的TensorFlow dataset API。在思考如何使每个卡取不同的数据,同时尽可能的提速,在论坛搜索了一下,思考有如下三种思路:

1、使用Dataset.batch()构造大batch的dataset,如4卡每卡batch size=6,那么就batch(24) 。然后在Iterator.get_next()之后tf.split(..., self.num_gpus),让每卡分到不同的batch。这应该是最简单的思路,不过split应该会降低速度。

2、用Dataset.batch()构造小batch的dataset,如每卡batch size=6,那么batch(6),然后在每卡上Iterator.get_next()。需要区分的是,如果Iterator.get_next()放在for i in range(num_gpus)之前,那么每卡读的batch应该是一样的。因此这种方法是指Iterator.get_next()放在循环里面。

3、创建多个iterator,每个GPU一个。在pipeline中使用dataset.Shard()对数据进行分片,请注意,此方法将消耗主机上的更多资源,因此可能需要减少buffer sizes 和degrees of parallelism。样例如下:

def input_fn(tfrecords_dirpath, num_gpus, batch_size,

            num_epochs, gpu_device, gpu_index):

    tfrecord_filepaths = tf.data.Dataset.list_files('{}/*.tfrecord'.format(tfrecords_dirpath))

    dataset = tf.data.TFRecordDataset(tfrecord_filepaths, num_parallel_reads= int(64 / num_gpus))

    dataset = dataset.shard(num_gpus, gpu_index)

    # use fused operations (shuffle_and_repeat, map_and_batch)

    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(10000, num_epochs))

    dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: parse_record(x), batch_size))

    # stage batches for processing by loading them pre-emptively on the GPU

    dataset = dataset.apply(tf.contrib.data.prefetch_to_device(gpu_device))

    iterator = dataset.make_one_shot_iterator()

    images_batch, labels_batch = iterator.get_next()

    return images_batch, labels_batch        

# create a separate inference graph in every GPU

gpu_devices = ['/gpu:{}'.format(i) for i in range(num_gpus)]

with tf.variable_scope(tf.get_variable_scope()):

    for i, gpu_device in enumerate(gpu_devices):

        # create a dataset and iterator per GPU

        image_batch, label_batch = input_fn(tfrecords_dirpath, num_gpus, batch_size_per_tower,

                                            num_epochs, gpu_device, i)

        with tf.device(gpu_device):

            with tf.name_scope('{}_{}'.format('tower', i)) as scope:

                # run inference and compute tower losses

                ...

上一篇 下一篇

猜你喜欢

热点阅读