Keras如何保存训练模型
2019-08-15 本文已影响0人
一位学有余力的同学
一、保存模型
方法一:通过Checkpoint保存
在Keras中有ModelCheckpoint函数,调用该函数可以将每个epoch后的模型进行保存。详见官方文档。具体的使用方法如下:
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1)
参数:
- filepath: 保存模型的路径
- monitor: 被监测的对象,比如
acc
,loss
,val_acc
,... - verbose: 冗余的,如果想加上进度条,就是1,如果不想,就是0
- set_best_only = True: 只保存最好的模型
- save_weights_only = False: 如果 True,那么只有模型的权重会被保存
model.save_weights(filepath)
, 否则的话,整个模型会被保存model.save(filepath)
- mode = 'auto':
auto
max
和min
,如果使监测acc
,就是max
,如果是监测loss
,就是min
,auto
就i会从监测值中自己判断 - period: 每个检查点之间的间隔(训练轮数)
方法二:通过save_model()保存
保存模型有两种形式,一种是将模型整个都保存下来,包括权重和结构;另一种是只保留权重。
1. 1保存整个模型
可以使用model.save(filepath)
将Keras的模型保存到HDF5文件中,该文件将包含:模型结构、模型权重、配置项(优化函数、优化器)和优化状态,允许准确地从上次结束地地方继续训练,详见官方文档。
from keras.model import load_model
model.save('my_model.h5')
model = load_model('my_model.h5')
1.2 保存部分模型
1.2.1 分别保存模型的权重和结构
我们可以使用to_json()
和to_yaml()
方法将模型结构保存到josn文件或者yaml文件中。
'''方法1:保存为json'''
json_string = model.to_json()
open('model_architecture_1.json', 'w').write(json_string) #重命名
#从json中读出数据
from keras.models import model_from_json
model = model_from_json(json_string)
'''方法2:保存为yaml'''
yaml_string = model.to_yaml()
open('model_arthitecture_2.yaml', 'w').write(yaml_string) #重命名
#从json中读出数据
from keras.models import model_from_yaml
model = model_from_yaml(yaml_string)
1.2.2 只保存模型权重
通过model.save_weights('my_model_weights.h5')
将权重保存在HDF5文件中。如果有可以实例化模型的代码。则可以将保存的权重加载到相同结构的模型中:
model.load_weights('my_model_weights.h5')
二、导入模型
如果我们想导入训练好的最好的模型来进行预测,最好使用方法一,将最好的模型保存下来然后导入进行预测,如果想接着上一次的模型继续训练,可以两种方法都可以。
from keras.models import load_model
model = load_model(filepath)
保存模型可能会出错的地方:
- filepath不会自己建立文件夹,例如
checkpointer = ModelCheckpoint(filepath='tmp\model.h5')
,如果同目录下没有tmp文件夹,程序将会出错,模型无法保存。 - 在进行训练的时候,需要把checkpointer加进去回调函数
callbacks
中,并用中括号扩起来,即model.fit(x, y, callbacks=[checkpointer])
附:保存模型图
我们可以通过model.summary()
将模型的结构打印出来,另外,我们可以通过plot_model(model, 'model_plot.png')
将模型的基本结构框图保存下来。另外你也可以使用keras.utils.vis_utils
模块将模型的详细结构框图保存下来。
from keras.utils import plot_model
plot_model(model, to_file='model.png')