Mixed Precision Training

2020-01-07 · 顾北向南

The original article comes from the 机器学习算法与自然语言处理 (Machine Learning Algorithms and NLP) WeChat official account.


Theory

Weight backup (master weights)

Keep an FP32 "master" copy of the weights. The forward and backward passes run in FP16, but the optimizer applies updates to the FP32 copy, so small updates are not rounded away at FP16 precision.

Loss scaling

Multiply the loss by a scale factor before backpropagation so that small gradients do not underflow to zero in FP16, then divide the gradients by the same factor before the weight update.

Precision of ops

Keep numerically sensitive operations, such as large reductions, softmax, and normalization statistics, in FP32, while matrix multiplies and convolutions run in FP16.
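
Taken together, the recipe is: run forward and backward in FP16, scale the loss up before backward and scale the gradients back down afterwards, and apply the update to an FP32 master copy. Below is a minimal hand-rolled sketch in PyTorch; it is purely illustrative (the toy linear model, the static loss scale of 1024, and the CUDA device are all assumptions; in practice, use the framework support shown in the following sections):

import torch

torch.manual_seed(0)

# FP16 model used for the forward/backward pass
model = torch.nn.Linear(16, 1).cuda().half()

# Master weights: an FP32 copy that actually receives the updates
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=1e-3)

loss_scale = 1024.0  # static scale for illustration; real systems adjust it dynamically

x = torch.randn(8, 16, device="cuda", dtype=torch.float16)
y = torch.randn(8, 1, device="cuda", dtype=torch.float16)

out = model(x)
loss = torch.nn.functional.mse_loss(out.float(), y.float())  # compute the loss in FP32

# Loss scaling: scale up before backward so tiny gradients survive FP16
(loss * loss_scale).backward()

# Move the FP16 gradients onto the FP32 master params, un-scaling on the way
for p, mp in zip(model.parameters(), master_params):
    mp.grad = p.grad.float() / loss_scale

optimizer.step()
optimizer.zero_grad()

# Copy the updated FP32 master weights back into the FP16 model
with torch.no_grad():
    for p, mp in zip(model.parameters(), master_params):
        p.copy_(mp)  # implicit cast back to FP16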

PyTorch

With NVIDIA's apex library, mixed precision only takes three extra lines; the full training loop below shows them in context:

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # the letter "O" followed by the digit 1, not zero-one
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

import torch
from apex import amp

model = ...
optimizer = ...

# Wrap the model and optimizer so amp can patch them for mixed precision
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for data, label in data_iter:
    out = model(data)
    loss = criterion(out, label)
    optimizer.zero_grad()

    # Loss scaling: replaces the plain loss.backward()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
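
For reference, PyTorch 1.6 and later ship native automatic mixed precision (torch.cuda.amp), which covers the same master-weight and loss-scaling machinery without apex. A sketch, assuming model, optimizer, criterion, and data_iter are defined as above:

import torch

scaler = torch.cuda.amp.GradScaler()  # dynamic loss scaler

for data, label in data_iter:
    optimizer.zero_grad()
    # autocast runs eligible ops (matmuls, convolutions, ...) in FP16
    with torch.cuda.amp.autocast():
        out = model(data)
        loss = criterion(out, label)
    scaler.scale(loss).backward()  # scaled backward, like amp.scale_loss
    scaler.step(optimizer)         # un-scales the gradients, skips the step on inf/NaN
    scaler.update()                # adjusts the loss scale for the next iteration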

TensorFlow

TensorFlow 1.x (1.14 and later) supports automatic mixed precision through a graph rewrite, enabled either globally via an environment variable or per-optimizer with a one-line wrapper:

import os
import tensorflow as tf

# Option 1: turn on automatic mixed precision globally
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

opt = tf.train.AdamOptimizer()

# Option 2: add one line to wrap the optimizer with the graph rewrite
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          opt,
          loss_scale='dynamic')

train_op = opt.minimize(loss)

The same one-line rewrite works with a tf.keras optimizer:

opt = tf.keras.optimizers.Adam()

# Add one line to wrap the optimizer
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            opt,
            loss_scale='dynamic')

model.compile(loss=loss, optimizer=opt)
model.fit(...)
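
In TF 2.x, the Keras API also offers a global dtype policy (under tf.keras.mixed_precision.experimental as of TF 2.1). A sketch with an assumed toy model: under the mixed_float16 policy, layers compute in FP16 while keeping FP32 variables, and Keras applies loss scaling for you during fit():

import tensorflow as tf

# Global policy: layers compute in float16 but keep float32 variables
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(16,)),
    # keep the output in float32 so the loss is computed in full precision
    tf.keras.layers.Dense(1, dtype='float32'),
])
model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam())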

PaddlePaddle

Paddle's BERT fine-tuning script turns on mixed precision with a single command-line flag; the example below fine-tunes BERT-base on XNLI:

export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='XNLI'
DATA_PATH=/path/to/xnli/data/
CKPT_PATH=/path/to/save/checkpoints/

# Mixed precision is enabled by adding the --use_fp16 flag below
python -u run_classifier.py --task_name ${TASK_NAME} \
                   --use_fp16 true \
                   --use_cuda true \
                   --do_train true \
                   --do_val true \
                   --do_test true \
                   --batch_size 32 \
                   --in_tokens false \
                   --init_pretraining_params ${BERT_BASE_PATH}/params \
                   --data_dir ${DATA_PATH} \
                   --vocab_path ${BERT_BASE_PATH}/vocab.txt \
                   --checkpoints ${CKPT_PATH} \
                   --save_steps 1000 \
                   --weight_decay  0.01 \
                   --warmup_proportion 0.1 \
                   --validation_steps 100 \
                   --epoch 3 \
                   --max_seq_len 128 \
                   --bert_config_path ${BERT_BASE_PATH}/bert_config.json \
                   --learning_rate 5e-5 \
                   --skip_steps 10 \
                   --num_iteration_per_drop_scope 10 \
                   --verbose true
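
Under the hood, the --use_fp16 flag wraps the optimizer with Paddle's mixed-precision decorator. A rough sketch against the Paddle 1.x fluid API (argument names and defaults may differ between versions, and the loss-scale value is an assumption):

import paddle.fluid as fluid
from paddle.fluid.contrib.mixed_precision import decorate

# What --use_fp16 enables inside the training script, roughly:
# wrap the optimizer so it inserts FP16 casts and loss scaling into the graph
optimizer = fluid.optimizer.Adam(learning_rate=5e-5)
optimizer = decorate(optimizer,
                     init_loss_scaling=128.0,        # starting loss scale (assumed value)
                     use_dynamic_loss_scaling=True)  # adapt the scale automatically
optimizer.minimize(loss)  # 'loss' comes from the network definition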