Keras Callbacks
2017-08-10 本文已影响0人
四碗饭儿
在每个training/epoch/batch结束时,如果我们想执行某些任务,例如模型缓存、输出日志、计算当前的auc等等,Keras中的callback就派上用场了。
Example 记录每个batch的损失函数值
import keras
# 定义callback类
class MyCallback(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.losses = []
return
def on_batch_end(self, batch, logs={}): # batch 为index, logs为当前batch的日志acc, loss...
self.losses.append(logs.get('loss'))
return
# 定义模型model
...
...
# 调用callback
cb = MyCallback()
# 训练模型
model.fit(x_train, y_train, batch_size=32, epochs=10, callbacks=[cb])
# 查看callback内容
cb.losses
如上述例子,我们可以继承keras.callbacks.Callback
来定义自己的callback,只需重写其中的6个方法即可
on_train_begin
on_train_end
on_epoch_begin
on_epoch_end
on_batch_begin
on_batch_end
可在这6个方法中定义自己想要的属性,通过self.model
可以访问模型本身,self.params
可以访问训练参数。
可能有用的属性
-
self.validation_data
validate数据集 -
self.validation_data[0]
为X -
self.validation_data[1]
为y