Pytorch Lightning系列 如何使用ModelChe
2021-08-30 本文已影响0人
四碗饭儿
在训练机器学习模型时,经常需要缓存模型。ModelCheckpoint
是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。Pytorch Lightning文档 介绍了ModelCheckpoint的详细信息。
我们来看几个有趣的使用示例。
示例1 注意,我们把epoch和val_loss信息也加入了模型名称。
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val_loss', #我们想要监视的指标
... dirpath='my/path/', #模型缓存目录
... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' # 模型名称
... )
示例2 这个使用例子非常像示例1,唯一的差别在于指标的名称是由我们自己指定的,而不是由Pytorch Lightning自动生成的 (auto_insert_metric_name=False
)。通过这样的方式,我们可以使用类似val/mrr
的指标名。从而统一tensorboard和pytorch lightning对指标的不同描述方式。
>>> checkpoint_callback = ModelCheckpoint(
... monitor='val/loss',
... dirpath='my/path/',
... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', # 注意到val/loss变成了val_loss
... auto_insert_metric_name=False
... )
Pytorch Lightning把ModelCheckpoint当作最后一个CallBack,也就是它总是在最后执行。这一点在我看来很别扭。如果你在训练过程中想获得best_model_score或者best_model_path,它对应的是上一次模型缓存的结果,而并不是最新的模型缓存结果
self.trainer.checkpoint_callback.best_model_score