模型保存

2021-05-25  本文已影响0人  三方斜阳

关于pytorch模型保存,在训练过程中常用,记录总结一下

# 保存整个网络
torch.save(net, PATH) 
# 保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)
#--------------------------------------------------
#针对上面一般的保存方法,加载的方法分别是:
model_dict=torch.load(PATH)
model_dict=model.load_state_dict(torch.load(PATH))

如何保存和重新加载微调模型,通常需要保存三种文件类型才能重新加载经过微调的模型:

模型权重文件:pytorch_model.bin
配置文件:config.json
词汇文件:vocab.txt

如果使用这些默认文件名保存模型,则可以使用from_pretrained()方法重新加载模型和tokenizer。

使用实例1:

from transformers import WEIGHTS_NAME, CONFIG_NAME
output_dir = 'models/'
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)
#如果我们有一个分布式模型,只保存封装的模型
#它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)

加载模型:

model = BERT_BiLSTM_CRF.from_pretrained(output_dir, need_birnn=args.need_birnn, rnn_dim=args.rnn_dim).to(device)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case)  
***Add specific options if needed

使用实例2:

model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

2. 保存整个模型:

保存模型
torch.save(model, 'models/model.pkl')
加载模型
model = torch.load('models/model.pkl')

3. 踩坑记录

之前有个代码,训练之后直接预测结果没有问题,模型保存好之后,,重新开一个文件加载模型进行预测,结果非常差,一直以为问题出在了模型加载问题上,网上找了很多类似的错误,可能的原因如下:

但是逐一排查之后都不是这个原因,这个时候就不应该怀疑是模型加载的问题,因为加载之前可以通过打印出 model.stat_dict() 来查看是否加载正确,这时候要从 测评文件数据的问题来思考,最后发现是因为单独预测的时候加载的测评标准文件中有个对照数据不一样,所以导致,训练出来的模型在新的数据上测评结果很差。

上一篇下一篇

猜你喜欢

热点阅读