Implementing precision, recall, and F1 in Keras
momo1023 · 2019-10-30
This post implements precision, recall, and F1 as custom metric functions for Keras.
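For reference, these are the standard definitions the functions below implement (TP = true positives, FP = false positives, FN = false negatives):

    precision = TP / (TP + FP)
    recall    = TP / (TP + FN)
    F_beta    = (1 + beta^2) * precision * recall / (beta^2 * precision + recall)

With beta = 1, F_beta reduces to the usual F1 score, the harmonic mean of precision and recall.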
The metric functions are defined as follows:
import keras
from keras import backend as K


def precision(y_true, y_pred):
    """Batch-wise precision: true positives / predicted positives."""
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return true_positives / (predicted_positives + K.epsilon())


def recall(y_true, y_pred):
    """Batch-wise recall: true positives / actual positives."""
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


def fbeta_score(y_true, y_pred, beta=1):
    """F-beta score, the weighted harmonic mean of precision and recall."""
    if beta < 0:
        raise ValueError('The lowest choosable beta is zero (only precision).')
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    bb = beta ** 2
    # K.epsilon() in the denominator keeps the score at ~0 when there are no
    # true positives, matching sklearn's convention. (A Python-level check like
    # `if K.sum(...) == 0` does not work on symbolic tensors, so it is omitted.)
    return (1 + bb) * (p * r) / (bb * p + r + K.epsilon())


def fmeasure(y_true, y_pred):
    """F1 score: fbeta_score with beta=1."""
    return fbeta_score(y_true, y_pred, beta=1)
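Before wiring these into a model, a quick sanity check on a toy batch helps; this is my own sketch, not part of the original post, and the values are hypothetical:

import numpy as np

# Toy batch: after rounding the predictions there are 2 TP, 1 FP, 1 FN.
y_true = K.constant(np.array([[1.], [1.], [0.], [1.], [0.]]))
y_pred = K.constant(np.array([[0.9], [0.8], [0.7], [0.2], [0.1]]))

print(K.eval(precision(y_true, y_pred)))  # 2 / 3 ≈ 0.667
print(K.eval(recall(y_true, y_pred)))     # 2 / 3 ≈ 0.667
print(K.eval(fmeasure(y_true, y_pred)))   # ≈ 0.667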
Usage when compiling a Keras model: pass the functions directly in the metrics list.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(),
              metrics=['accuracy', precision, recall, fmeasure])
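Note that these functions are evaluated per batch, and the logged value is the running average over batches, which only approximates the true epoch-level score; this batch-wise behavior is why precision, recall, and fbeta_score were removed from the built-in Keras metrics in 2.0. If you are on tf.keras, the built-in stateful metrics accumulate counts across the whole epoch instead. A minimal alternative compile, assuming TensorFlow 2.x (tf.keras has no built-in F1 metric here, so fmeasure above can still be used alongside):

import tensorflow as tf

# Stateful metrics accumulate TP/FP/FN over the epoch rather than
# averaging per-batch values.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=['accuracy',
                       tf.keras.metrics.Precision(name='precision'),
                       tf.keras.metrics.Recall(name='recall')])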
Training (timed with the %%time cell magic):
%%time
history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)
Training output:
Train on 22500 samples, validate on 2500 samples
Epoch 1/5
22500/22500 [==============================] - 91s 4ms/sample - loss: 0.4759 - acc: 0.7517 - precision: 0.7494 - recall: 0.6967 - fmeasure: 0.6917 - val_loss: 0.3308 - val_acc: 0.8664 - val_precision: 0.8631 - val_recall: 0.8118 - val_fmeasure: 0.8346
Epoch 2/5
22500/22500 [==============================] - 89s 4ms/sample - loss: 0.2808 - acc: 0.8901 - precision: 0.8908 - recall: 0.8930 - fmeasure: 0.8894 - val_loss: 0.2861 - val_acc: 0.8876 - val_precision: 0.8827 - val_recall: 0.8956 - val_fmeasure: 0.8874
Epoch 3/5
22500/22500 [==============================] - 95s 4ms/sample - loss: 0.2016 - acc: 0.9294 - precision: 0.9273 - recall: 0.9329 - fmeasure: 0.9288 - val_loss: 0.2807 - val_acc: 0.8940 - val_precision: 0.8985 - val_recall: 0.8894 - val_fmeasure: 0.8918
Epoch 4/5
22500/22500 [==============================] - 112s 5ms/sample - loss: 0.1448 - acc: 0.9546 - precision: 0.9534 - recall: 0.9568 - fmeasure: 0.9541 - val_loss: 0.2996 - val_acc: 0.8852 - val_precision: 0.8991 - val_recall: 0.8697 - val_fmeasure: 0.8816
Epoch 5/5
22500/22500 [==============================] - 90s 4ms/sample - loss: 0.0917 - acc: 0.9755 - precision: 0.9738 - recall: 0.9772 - fmeasure: 0.9751 - val_loss: 0.3276 - val_acc: 0.8888 - val_precision: 0.8770 - val_recall: 0.8925 - val_fmeasure: 0.8810
CPU times: user 22min 38s, sys: 3min, total: 25min 39s
Wall time: 7min 58s
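After training, the same metrics can be reported on held-out data. A quick sketch, assuming a test split named x_test / y_test (these names are my assumption, not from the original post):

# model.evaluate returns [loss, acc, precision, recall, fmeasure],
# matching the order passed to metrics= in compile().
loss, acc, prec, rec, f1 = model.evaluate(x_test, y_test, batch_size=64)
print('test: precision=%.4f recall=%.4f f1=%.4f' % (prec, rec, f1))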