Knowledge distillation with Keras on a single machine with multiple GPUs
FreeTheWorld · 2022-02-16
The official Keras documentation shows how to write a simple knowledge-distillation model, but how should it be written when mirrored_strategy is enabled for single-machine multi-GPU (or multi-machine multi-GPU) training?
I have verified this on TensorFlow 2.4: the model below runs successfully with mirrored_strategy enabled. The key code follows, with the important points explained in the comments.
The Distiller model
import tensorflow as tf
from tensorflow import keras


class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn,
                alpha=0.1, temperature=3, global_batch_size=None):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        # Global batch size (per-replica batch * number of replicas); needed to
        # average the per-example losses correctly under MirroredStrategy.
        self.global_batch_size = global_batch_size

    def train_step(self, data):
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data
        teacher_predictions = self.teacher(x, training=False)
        # Note: the per-replica batch size = global batch size / number of GPUs
        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            # Compute per-example losses (both loss fns use Reduction.NONE)
            student_loss = self.student_loss_fn(y, student_predictions)  # shape = (batch_size,)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions,
                temperature=self.temperature)  # shape = (batch_size,)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss  # shape = (batch_size,)
            # Scale by the global batch size so that the gradients summed across
            # replicas are equivalent to one large-batch update.
            loss = tf.nn.compute_average_loss(loss, global_batch_size=self.global_batch_size)
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": tf.reduce_mean(student_loss),
             "distillation_loss": tf.reduce_mean(distillation_loss)}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
        # Compute predictions (same single-output student as in train_step)
        y_prediction = self.student(x, training=False)
        # Calculate the per-example loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": tf.reduce_mean(student_loss)})
        return results
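The original post does not show the teacher_model(), student_model(), or get_metrics() helpers that the training code below relies on. For context, here is a minimal sketch of what they could look like in this binary-classification setup; the layer sizes, the NUM_FEATURES input dimension, and the metric choices are illustrative assumptions, not taken from the original.

NUM_FEATURES = 128  # assumed input dimension, for illustration only

def teacher_model():
    # Larger network; compiled directly because it is trained on its own first.
    model = keras.Sequential([
        keras.layers.Dense(256, activation="relu", input_shape=(NUM_FEATURES,)),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),  # probability output, matches from_logits=False
    ])
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001),
                  loss=keras.losses.BinaryCrossentropy(),
                  metrics=[keras.metrics.AUC(name="auc")])
    return model

def student_model():
    # Smaller than the teacher; not compiled here, since the Distiller's custom
    # train_step drives its optimization.
    return keras.Sequential([
        keras.layers.Dense(64, activation="relu", input_shape=(NUM_FEATURES,)),
        keras.layers.Dense(1, activation="sigmoid"),
    ])

def get_metrics():
    # Any Keras metrics suitable for binary classification work here.
    return [keras.metrics.AUC(name="auc"), keras.metrics.BinaryAccuracy(name="acc")]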
The training procedure (train)
def train(batch, train_dataset, test_dataset):
    # `gpus` (a list of device strings), `ARGS`, `get_metrics()`, `teacher_model()`
    # and `student_model()` are defined elsewhere in the project.
    mirrored_strategy = tf.distribute.MirroredStrategy(devices=gpus)
    # Global batch size = per-replica batch * number of replicas
    batch_size = batch * mirrored_strategy.num_replicas_in_sync
    logging.info("batch size: %d, %d gpus", batch_size, mirrored_strategy.num_replicas_in_sync)

    # Train the teacher first.
    with mirrored_strategy.scope():
        teacher = teacher_model()
        teacher.fit(train_dataset,
                    validation_data=test_dataset,
                    validation_steps=4,
                    epochs=4)

    with mirrored_strategy.scope():
        student = student_model()
        # Binary-classification loss here; for multi-class, switch to a softmax-based loss.
        # All loss objects must be created inside mirrored_strategy.scope().
        # The reduction argument is required: Reduction.NONE makes each replica return
        # per-example losses, which are then averaged with compute_average_loss.
        binary_loss = tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.NONE)

        def distill_loss(teacher_predictions, student_predictions, sample_weight=None, temperature=None):
            # Soften both distributions: recover the logit from the sigmoid output,
            # divide it by the temperature, then apply the sigmoid again.
            if temperature and temperature != 1:
                teacher_predictions = tf.sigmoid(
                    tf.math.log(teacher_predictions / (1 - teacher_predictions)) / temperature)
                student_predictions = tf.sigmoid(
                    tf.math.log(student_predictions / (1 - student_predictions)) / temperature)
            per_example_loss = binary_loss(teacher_predictions, student_predictions, sample_weight=sample_weight)
            return per_example_loss

        distiller = Distiller(student=student, teacher=teacher)
        distiller.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.001, amsgrad=True),
            metrics=get_metrics(),
            student_loss_fn=binary_loss,
            distillation_loss_fn=distill_loss,
            alpha=ARGS.alpha,
            temperature=ARGS.temperature,
            global_batch_size=batch_size)
        distiller.fit(train_dataset,
                      validation_data=test_dataset,
                      validation_steps=4,
                      epochs=6)
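To close the loop, here is a hedged usage sketch of how the datasets might be built and train() invoked, assuming `gpus` and `ARGS` are defined as in the original script. The synthetic data, the NUM_FEATURES dimension from the sketch above, the per-replica batch size of 256 and the two-GPU setup are purely illustrative; the one load-bearing detail is that under MirroredStrategy Keras splits each dataset batch across the replicas, so the datasets are batched with the global batch size (per-replica batch × number of replicas), matching the comment in train_step.

import numpy as np

num_gpus = 2                       # should match the `gpus` device list used in train()
per_replica_batch = 256            # illustrative
global_batch = per_replica_batch * num_gpus

# Synthetic data, just to make the sketch self-contained.
x_train = np.random.rand(10000, NUM_FEATURES).astype("float32")
y_train = np.random.randint(0, 2, size=(10000, 1)).astype("float32")
x_test = np.random.rand(4096, NUM_FEATURES).astype("float32")
y_test = np.random.randint(0, 2, size=(4096, 1)).astype("float32")

# Batch with the global batch size: fit() under MirroredStrategy splits each
# batch across the replicas, giving each GPU global_batch / num_gpus examples.
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(10000)
                 .batch(global_batch)
                 .prefetch(tf.data.experimental.AUTOTUNE))
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .batch(global_batch)
                .prefetch(tf.data.experimental.AUTOTUNE))

train(per_replica_batch, train_dataset, test_dataset)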