分类问题中混淆矩阵的目的是什么?

数据挖掘 分类 交叉验证 公制 混淆矩阵 网格搜索
2022-02-22 00:49:56

我正在学习机器学习。经过一些研究,我了解到分类问题的典型工作流程(在准备好数据之后)如下:

  1. 在测试、训练和验证集中拆分数据
  2. 训练模型
  3. 生成混淆矩阵
  4. 分析指标:accuracy、precision、recall 和 f1
  5. 根据我决定优化的指标调整超参数。

我的问题是:为什么我们需要混淆矩阵?考虑到我们要解决的问题类型,我们不应该已经知道需要优化什么指标吗?

我问这个是因为,据我了解,如果我们有足够的计算能力,我们基本上可以通过应用网格搜索(基本上包括每个调整参数的交叉验证)将步骤 2 和 5 分组,该网格搜索为输入要测量的指标。这意味着您需要事先了解指标,而且您无法获得混淆矩阵。

提前感谢您的回复。

2个回答

好的,那么让我为您回答其中一些问题:

  • 混淆矩阵的目的是什么?

混淆矩阵只是一种视觉帮助,可以帮助您更好地解释模型的性能。这是一种以图形方式可视化真阳性 (TP)、假阳性 (FP)、真阴性 (TN) 和假阴性 (FN) 的方法。当您处理大量不同的类时,混淆矩阵变得越来越有用。它可以让您深入了解模型的运行情况。假设您正在训练一个图像检测分类器。很高兴知道您的模型混淆了狗和狼,但没有混淆猫和蛇。

混淆矩阵的另一个目的是定义一个相关的成本矩阵。在我的例子中,狗和狼之间的混淆可能是可以理解的,这不应该意味着你的模型不擅长它的工作。但是,如果在不应该混淆的类之间混淆,则应该在您的性能指标中正确表示。

这是一个很好的博客,详细介绍了这些概念:https ://medium.com/@inivikrant/confusion-cost-matrix-helps-in-calculating-the-accuracy-cost-and-various-other-measurable-a725fb6b54e1

  • 考虑到我们要解决的问题类型,我们不应该已经知道需要优化什么指标吗?

在这里,您混淆了两件事。一方面,是的,您应该提前知道要优化哪个指标(即准确度、精确度、召回率等),但这并不意味着您提前知道该指标的价值。如果你对超参数调优进行简化,大致是这样的:

  1. 使用超参数训练模型MH
  2. 评估模型PM
  3. 选择新的超参数并重复步骤 1 和 2H
  4. 选择具有超参数的模型 ,以便优化MHP

如果您知道 TP、FP、TN、FP,则可以计算准确度、精度、召回率或 F1 分数(有关更多信息,请参阅此链接)。所以从技术上讲,您不必创建混淆矩阵本身,但您肯定需要计算 TP、FP、TN、FP 来评估模型的性能。

  • 如果我们有足够的计算能力,我们基本上可以将步骤 2 和 5 分组

只有为每组超参数计算模型的性能,才能优化超参数。您可以跳过第 3 步,因为它在技术上不会影响您的训练过程。它只会帮助您更好地了解正在发生的事情。但是你绝对不能跳过第 4 步。

为上述答案加 2 分——

  • 混淆矩阵为商务人士提供了更好的比较图。例如,如果您通知您的企业F1 分数为 0.9,那么对他来说使用较少。
    但他会喜欢你说 - 模型将错过 100 例癌症病例中的 9 例,并在没有时将 10000 例中的​​ 50 例报告为癌症。

  • 当你有超过 2 个类时,CM 会给出模型所犯的学习错误的想法。例如,在 Fashion MNIST 数据中,我们可以观察到模型在 Shirt 和 Coat 之间存在混淆。您可以相应地进行调整。见下图

图片来源- https://www.kaggle.com/fuzzywizard/fashion-mnist-cnn-keras-accuracy-93 混淆矩阵