pytorch 模型保存方式(.pt, .pth, .pkl)

2020-05-11  本文已影响0人  京漂的小程序媛儿

模型不同后缀名的区别

经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同。

在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。

模型保存与调用方式一:

只保存模型参数,不保存模型结构

保存:

torch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构

调用:

model = My_model(*args, **kwargs)  #这里需要重新模型结构,My_model

model.load_state_dict(torch.load(mymodel.pth))#这里根据模型结构,调用存储的模型参数

model.eval()

模型保存与调用方式二:

保存整个模型,包括模型结构+模型参数

保存:

torch.save(model, mymodel.pth)#保存整个model的状态

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以

model.eval()

举个例子

Bert模型结构,仅保存模型参数

保存时:

torch.save(model.state_dict(), file_name) # 这个model是已经训练好的模型

调用时:

model.load_state_dict(torch.load(file_name, map_location=device if device =='cpu' else "{}:{}".format(device, 0)))  # 这个model是Bert

上一篇 下一篇

猜你喜欢

热点阅读