机器学习学社机器学习与数据挖掘机器学习和人工智能入门

机器学习中的混淆矩阵

2019-07-03  本文已影响0人  一枚不只关注技术的技术宅
混淆矩阵

让混淆矩阵不再混淆

混淆矩阵是用于总结分类算法性能的技术。如果每个类中的样本数量不等,或者数据集中有两个以上的类,则仅用分类准确率作为评判标准的话可能会产生误导。计算混淆矩阵可以让我们更好地了解分类模型的表现情况以及它所犯的错误的类型。

阅读这篇文章后你会了解到:

1.分类准确率及其局限性

分类准确率是预测正确的样本数与总样本数的比值

即:分类准确率 = 预测正确的样本数 / 总样本数

错误率即: 错误率 = 1 - 分类准确率

分类准确率看上去是一个不错的评判标准,但在实际当中往往却存在着一些问题。其主要问题在于它隐藏了我们需要的细节,从而阻碍我们更好地理解分类模型的性能。 有两个最常见的例子:

  1. 当我们的数据有2个以上的类时,如3个或更多类,我们可以得到80%的分类准确率,但是我们却不知道是否所有的类别都被预测得同样好,或者说模型是否忽略了一个或两个类;

  2. 当我们的每个类中的样本数并不接近时,我们可以达到90%或更高的准确率,但如果每100个记录中有90个记录属于一个类别,则这不是一个好分数,我们可以通过始终预测最常见的类值来达到此分数。如90个样本都属于类别1,则我们的模型只需要预测所有的样本都属于类别1,便可以达到90%或更高的准确率。

分类准确率可以隐藏诊断模型性能所需的详细信息,但幸运的是,我们可以通过混淆矩阵来进一步区分这些细节。

2.什么是混淆矩阵

混淆矩阵是对分类问题的预测结果的总结。使用计数值汇总正确和不正确预测的数量,并按每个类进行细分,这是混淆矩阵的关键所在。混淆矩阵显示了分类模型的在进行预测时会对哪一部分产生混淆。它不仅可以让您了解分类模型所犯的错误,更重要的是可以了解哪些错误类型正在发生。正是这种对结果的分解克服了仅使用分类准确率所带来的局限性。

3.怎么计算混淆矩阵

  1. 我们需要具有类别标签的测试数据集或验证数据集;

  2. 对测试数据集中的每一行进行预测;

  3. 从类别标签和预测结果我们可以得出:

    • 每个类别的正确预测数量;

    • 每个类的错误预测数;

  4. 将这些数字组织成表格或矩阵,如下所示:

    • 表格左边由上至下:矩阵的每一行对应一个预测的类;

    • 表格上部:矩阵的每列对应于实际的类;

    • 将正确和不正确分类的计数填入表中;

  5. 将类别的正确预测总数填进该类值的标签行和该类值的预测列

  6. 将类别的错误预测总数填进该类值的标签行和该类值的预测列。

该矩阵可用于易于理解的二类分类问题,但通过向混淆矩阵添加更多行和列,可轻松应用于具有3个或更多类值的问题。

让我们通过一个例子来解释如何建立混淆矩阵。

4.二类分类问题混淆矩阵的建立

让我们假设我们有一个两类分类问题,即预测照片内有男人还是女人。

我们有一个包含10个记录的测试数据集,如下是标签分类和模型的预测结果:

标签分类 预测结果

首先,我们来计算分类准确率:

准确率 = 所有正确预测的样本 / 所有样本 = 7 / 10 = 70%

接着,我们来建立混淆矩阵:

我们需要先算出每一个类当中预测正确的数量:

“男”类:3个

“女”类:4个

然后,我们计算每一个类当中预测错误的数量:

“男”类:2个

“女”类:1个

现在我们可以把结果填入到一个2X2的矩阵当中:

男(标签值) 女(标签值)
男(预测值) 3 1
女(预测值) 2 4

5.二类分类问题的特殊混淆矩阵

在二类分类问题中,我们经常寻求从正常数据中区分出具有特殊性质的数据,如患病与未患病。通过这种方式,我们可以将患病事件行指定为“Positive(阳性)”,将未患病事件行指定为“Negative(阴性)”。 然后,我们可以将预测患病指定为“True(真)”,将预测未患病事件指定为“False(假)”。

由此我们可以得到:

我们可以在混淆矩阵中总结如下:

患病(实际) 未患病(实际)
患病(预测) 真阳性 假阳性
未患病(预测) 假阴性 真阴性

从混淆矩阵当中,我们可以得到更高级的分类指标:Precision(准确率),Recall(召回率),Specificity(特异性),Sensitivity(灵敏度)。

现在我们已经掌握了如何去构造混淆矩阵,接下来我们看看如何使用Python来帮助我们完成这一任务。

6.利用Python中的Scikit-Learn构造混淆矩阵

Python中的scikit-learn库可以用于计算混淆矩阵。给定标签值以及模型预测结果,使用confusion_matrix()函数计算混淆矩阵并将结果作为数组返回。 然后,我们可以打印此数组并解释结果。

# 利用Python构造混淆矩阵的例子:
from sklearn.metrics import confusion_matrix

expected = [1, 1, 0, 1, 0, 0, 1, 0, 0, 0]
predicted = [1, 0, 0, 1, 0, 0, 1, 1, 1, 0]

results = confusion_matrix(expected, predicted)
print(results)

运行以上示例会输出以下形式的混淆矩阵:

[[4, 2], 
[1, 3]]

总结

在这篇文章中,我们了解到:

如果你任何问题的话,请在文章下面回复,我会尽可能的回答大家的问题。



想成为一名合格的数据科学家/AI工程师吗?关注“机器学习学社”获取每天一份的新鲜咨询,为你开拓属于你自己的AI之路

扫码关注机器学习学社,更多资讯早知道,更多学习干货等你来拿
上一篇 下一篇

猜你喜欢

热点阅读