如何在不平衡类数据集中获得正确的混淆矩阵?

数据挖掘 Python 分类 scikit-学习 逻辑回归
2022-02-14 01:53:20

我创建了 3 个类的两个模拟随机数据集。数据集之间的唯一区别是类的频率。

Dataset A: (Class 0 = 300, Class 1 =200, Class 2 = 500)
Dataset B: (Class 0 = 500, Class 1 =500, Class 2 = 500)

两者都是随机数据集,所以我应该期望从逻辑回归模型中混淆每个具有相同频率的类。这意味着在归一化混淆矩阵中,我应该期望所有三个类之间的混淆比例相等。

Confusion matrix of Dataset A

在此处输入图像描述

Confusion matrix of Dataset B 

在此处输入图像描述

我对数据集 A 的期望与数据集 B 相同。但我无法实现。为什么?我在 python 中使用以下命令来运行逻辑回归模型。

log_reg_model = LogisticRegression(C=1,penalty='l1',multi_class='ovr',class_weight='balanced',solver='liblinear')
pipe=Pipeline([('StandardScaler',StandardScaler()), ('logistic_regression',log_reg_model)])

编辑:我正在以下 Dropbox 链接中上传这两个数据集。 https://www.dropbox.com/sh/pkiapvqy3k3f12v/AADpeBJ0XTWA2v9MCjALBcexa?dl=0 第一列是索引,第二列是类id,第三到第五列是类特征。

3个回答

如果您的数据集是随机的(类别和预测变量之间没有真正的联系),那么“正确”模型是一个常数:在 (A) 中,预测概率应该大致为0.3,0.2,0.5, 而在 (B) 中它们应该是0.33,0.33,0.33. 然后在制作硬分类器时,在 (A) 中,最大概率几乎总是第三类,而在 (B) 中,每个类应该被大致相等地随机预测。当然会有一些偏差,因为您的模型将尝试在其随机选择的训练集中学习一些模式,但您应该期待与您所展示的非常相似的东西。

在不平衡的数据集中,您可能不想只选择最有可能的类别。如果你有一组更具预测性的特征可能没问题,因为模型可能能够学到足够的知识来克服先验概率,但是在这里有了不可预测的特征,就无法学到任何新东西。

您可以查看 SMOTE 和 ADAYSN 技术。这将帮助您通过创建合成数据来减少数据集中的不平衡

https://medium.com/coinmonks/smote-and-adasyn-handling-imbalanced-data-set-34f5223e167

我假设,您的重点是预测准确性而不是可解释性?

因此,由于存在类不平衡,您可以做两件事:

  1. 根据其他用户的建议,您可以使用 SMOTE 或任何技术。
  2. 使用在处理类不平衡方面更稳健的非参数方法。

我尝试Random Forest在您的数据上使用,分类结果已经很有希望,无需任何参数调整。我使用的R代码:

library(randomForest)

data_A$V2 = factor(data_A$V2, levels = c(0, 1,2))

set.seed(4)
classifier_random <- randomForest(V2~V3+V4+V5, data=data_A, ntree=500)
pred_forest <- predict(classifier_random, data_A[,c('V3','V4','V5')])

table(data_A$V2, pred_forest)

pred_forest
      0   1   2
  0 297   0   3
  1  13 180   7
  2   9   7 484

如果我以 80:20 的比例将它分成训练集和测试集怎么办?

smp_size <- floor(0.80 * nrow(data_A))

set.seed(4)
train_ind <- sample(seq_len(nrow(data_A)), size = smp_size)

train <- data_A[train_ind, ]
test <- data_A[-train_ind, ]

classifier_random <- randomForest(V2~V3+V4+V5, data=train, ntree=500)
pred_forest <- predict(classifier_random, test[,c('V3','V4','V5')])

table(test$V2, pred_forest)

pred_forest
     0  1  2
  0 17 11 33
  1 14  5 24
  2 16 13 67

错误分类更多,但它仍然有效,也没有任何调整。

其中的代码R可能对您并不完全有益,但我希望您能理解我在这里试图传达的观点。