(二)sklearn.metrics模型性能评价
2019-11-01 本文已影响0人
神经网络爱好者
目录
1、模型评价指标
2、sklearn.metrics.classification_report的使用
3、混淆矩阵的使用
1、模型评价指标
准确率(accuracy):分对的样本数除以所有的样本数 。
准确率一般用来评估模型的全局准确程度,不能包含太多信息,全面评价一个模型,其中混淆矩阵是一个常用的手段。
![](https://img.haomeiwen.com/i8197259/0d2c6eae176ebcdc.png)
precision(查准率):预测为正的样本当中有多少预测准确了。
recall(查全率):真正为正的样本当中有多少被预测准确了了。
对于实际多分类问题:
如果类别A的P很低,R很高,则表示有其他类别的样本大量被预测为A,可能是类别之间很相似,也可能是数据集分错了;
如果类别A的P很高,R很低,则表示类别A的样本被预测为其他类了。
因此为了综合P和R,可以使用Fα-Score:
如果小于1则侧重P,大于1侧重R,等于1相当于调和平均数,F1-Score.
2、sklearn.metrics.classification_report的使用
classfication_report
函数以文本的方式给出了分类结果的主要预测性能指标。其原型为:
sklearn.metrics.classification_report(y_true, y_pred, labels=None, target_names=None,
sample_weight=None, digits=2)
返回值:一个格式化的字符串,给出了分类评估报告。
参数:
-
y_true
:真实的标记集合。 -
y_pred
:预测的标记集合。 -
labels
:一个列表,指定报告中出现哪些类别。 -
target_names
:一个列表,指定报告中类别对应的显示出来的名字。 -
digits
:用于格式化报告中的浮点数,保留几位小数。 -
sample_weight
:样本权重,默认每个样本的权重为 1。
from sklearn.metrics import classification_report
y_true = [0, 1, 2, 2, 2]
y_pred = [0, 1, 2, 2, 1]
label = [0, 1, 2]
target_names = ['class 0', 'class 1', 'class 2']
print((classification_report(y_true, y_pred, labels=label, target_names=target_names)))
此时的程序输出如下:
precision recall f1-score support
class 0 1.00 1.00 1.00 1
class 1 0.50 1.00 0.67 1
class 2 1.00 0.67 0.80 3
avg / total 0.90 0.80 0.81 5
如果使label=[0,1]
,则得到的输出如下:
precision recall f1-score support
class 0 1.00 1.00 1.00 1
class 1 0.50 1.00 0.67 1
avg / total 0.75 1.00 0.83 2
3、混淆矩阵的使用
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 2, 2, 2]
y_pred = [0, 1, 2, 2, 1]
target_names = ['class 0', 'class 1', 'class 2']
#混淆矩阵
sns.set()
f,ax = plt.subplots()
colormap = sns.cm.rocket_r#sns.cubehelix_palette(as_cmap=True,reverse=False)
C2 = confusion_matrix(y_true, y_pred)
C2 = pd.DataFrame(C2, index=target_names, columns=target_names)
sns.heatmap(C2, annot=True, ax=ax, cmap=colormap) #画热力图
ax.set_title('confusion matrix') #标题
ax.set_xlabel('predict') #x轴
ax.set_ylabel('true') #y轴
plt.show()
得到的图片如下:
![](https://img.haomeiwen.com/i8197259/70ef8994f229d7de.png)