Implementing precision, recall, and F1 in Keras

2019-10-30  momo1023

This post implements precision, recall, and F1 as metric functions with the Keras backend.

The functions are defined as follows:

import keras
from keras import backend as K

def precision(y_true, y_pred):
    # Precision: true positives / predicted positives.
    # Values are clipped to [0, 1] and rounded, so predictions are thresholded at 0.5.
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    # K.epsilon() guards against division by zero.
    return true_positives / (predicted_positives + K.epsilon())

def recall(y_true, y_pred):
    # Recall: true positives / actual positives.
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    # K.epsilon() guards against division by zero.
    return true_positives / (possible_positives + K.epsilon())

def fbeta_score(y_true, y_pred, beta=1):
    # Calculates the F score, the weighted harmonic mean of precision and recall.
    if beta < 0:
        raise ValueError('The lowest choosable beta is zero (only precision).')

    # No explicit zero check is used here: a Python `if` on a symbolic tensor
    # is not evaluated per batch in graph mode. When there are no true
    # positives, p * r is 0 and K.epsilon() keeps the denominator non-zero,
    # so the F score comes out as 0, like sklearn.
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    bb = beta ** 2
    return (1 + bb) * (p * r) / (bb * p + r + K.epsilon())

def fmeasure(y_true, y_pred):
    # Calculates the f-measure, the harmonic mean of precision and recall.
    return fbeta_score(y_true, y_pred, beta=1)
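
To sanity-check the formulas, the same arithmetic can be reproduced in plain Python on a tiny hand-made example. This is only a sketch and not part of the original code: `reference_scores`, `clip01`, `EPSILON`, and the sample values are hypothetical names and data, and `EPSILON = 1e-7` is assumed to match the default `K.epsilon()`.

```python
EPSILON = 1e-7  # assumption: mirrors the default value of K.epsilon()


def clip01(value):
    """Clamp a number to the range [0, 1], like K.clip(x, 0, 1)."""
    return min(max(value, 0.0), 1.0)


def reference_scores(y_true, y_pred, beta=1.0):
    """Return (precision, recall, F-beta) for flat lists of labels and scores."""
    true_positives = sum(round(clip01(t * p)) for t, p in zip(y_true, y_pred))
    predicted_positives = sum(round(clip01(p)) for p in y_pred)
    possible_positives = sum(round(clip01(t)) for t in y_true)

    p = true_positives / (predicted_positives + EPSILON)
    r = true_positives / (possible_positives + EPSILON)
    bb = beta ** 2
    f = (1 + bb) * (p * r) / (bb * p + r + EPSILON)
    return p, r, f


# Hypothetical labels and sigmoid outputs for five samples.
y_true = [1, 1, 0, 0, 1]
y_pred = [0.9, 0.4, 0.7, 0.2, 0.8]

# 2 true positives, 3 predicted positives, 3 actual positives,
# so all three scores are close to 2/3.
p, r, f1 = reference_scores(y_true, y_pred)
print(p, r, f1)
```

With no positive labels at all, the same function returns an F score of 0 without any special case, because the numerator is 0 and the epsilon keeps the denominator non-zero.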

Using them in a Keras model:

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(),
              metrics=['accuracy', precision, recall, fmeasure])

Training:

%%time
history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)

Training output:

Train on 22500 samples, validate on 2500 samples
Epoch 1/5
22500/22500 [==============================] - 91s 4ms/sample - loss: 0.4759 - acc: 0.7517 - precision: 0.7494 - recall: 0.6967 - fmeasure: 0.6917 - val_loss: 0.3308 - val_acc: 0.8664 - val_precision: 0.8631 - val_recall: 0.8118 - val_fmeasure: 0.8346
Epoch 2/5
22500/22500 [==============================] - 89s 4ms/sample - loss: 0.2808 - acc: 0.8901 - precision: 0.8908 - recall: 0.8930 - fmeasure: 0.8894 - val_loss: 0.2861 - val_acc: 0.8876 - val_precision: 0.8827 - val_recall: 0.8956 - val_fmeasure: 0.8874
Epoch 3/5
22500/22500 [==============================] - 95s 4ms/sample - loss: 0.2016 - acc: 0.9294 - precision: 0.9273 - recall: 0.9329 - fmeasure: 0.9288 - val_loss: 0.2807 - val_acc: 0.8940 - val_precision: 0.8985 - val_recall: 0.8894 - val_fmeasure: 0.8918
Epoch 4/5
22500/22500 [==============================] - 112s 5ms/sample - loss: 0.1448 - acc: 0.9546 - precision: 0.9534 - recall: 0.9568 - fmeasure: 0.9541 - val_loss: 0.2996 - val_acc: 0.8852 - val_precision: 0.8991 - val_recall: 0.8697 - val_fmeasure: 0.8816
Epoch 5/5
22500/22500 [==============================] - 90s 4ms/sample - loss: 0.0917 - acc: 0.9755 - precision: 0.9738 - recall: 0.9772 - fmeasure: 0.9751 - val_loss: 0.3276 - val_acc: 0.8888 - val_precision: 0.8770 - val_recall: 0.8925 - val_fmeasure: 0.8810
CPU times: user 22min 38s, sys: 3min, total: 25min 39s
Wall time: 7min 58s
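
In the epoch 1 line above, fmeasure (0.6917) is lower than both precision (0.7494) and recall (0.6967), which a harmonic mean of those two numbers cannot be. A likely explanation is that function metrics like these are evaluated on each batch and the progress bar reports the mean over batches, so the logged values only approximate the scores over the whole dataset. The plain-Python sketch below illustrates the effect with two made-up batches; `f1_from_labels`, `EPSILON`, and the data are hypothetical, and predictions are assumed to be already thresholded to 0 or 1.

```python
EPSILON = 1e-7  # assumption: mirrors the default value of K.epsilon()


def f1_from_labels(y_true, y_pred):
    """F1 for flat lists of 0/1 labels and 0/1 predictions."""
    true_positives = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    precision = true_positives / (sum(y_pred) + EPSILON)
    recall = true_positives / (sum(y_true) + EPSILON)
    return 2 * precision * recall / (precision + recall + EPSILON)


# Two hypothetical batches: the first is predicted perfectly,
# the second contains no true positives.
batch_a_true, batch_a_pred = [1, 1, 1, 1], [1, 1, 1, 1]
batch_b_true, batch_b_pred = [1, 0, 0, 0], [0, 1, 0, 0]

# Mean of the per-batch F1 values (what a per-batch metric would report).
mean_of_batches = (f1_from_labels(batch_a_true, batch_a_pred)
                   + f1_from_labels(batch_b_true, batch_b_pred)) / 2

# F1 computed once over all samples.
global_f1 = f1_from_labels(batch_a_true + batch_b_true,
                           batch_a_pred + batch_b_pred)

print(mean_of_batches, global_f1)  # roughly 0.5 versus roughly 0.8
```

If the installed Keras version provides the stateful `keras.metrics.Precision` and `keras.metrics.Recall` classes, those accumulate counts across batches and should avoid this effect. For exact figures, the scores can also be computed once on the full validation predictions after training.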