如何为交叉验证的每一折获得多类分类的敏感性和特异性?

数据挖掘 交叉验证 多类分类
2022-02-18 15:08:49

我正在研究一个由 4 个类组成的多类分类。我正在对其应用 5 折交叉验证,并希望获得每个折痕的敏感性(召回)和特异性得分。

我发现使用cross_validate函数,我可以为它提供每个折叠的评分参数列表。

    scoring = {'accuracy' : make_scorer(accuracy), 
               'precision' : make_scorer(precision_score),
               'recall' : make_scorer(recall_score), 
               'f1_score' : make_scorer(f1_score)}

    cross_validate(neural_network, data, y, cv=5,scoring=scoring)

但是,这会产生错误,因为这些函数(精度除外)仅用于二分类而不用于多分类。

因此,我决定为灵敏度得分和特异性得分创建自己的函数,返回 4 个单独值的平均值(每个类别 1 个)。我返回它们的平均值,而不是单独的 4 个值,因为不允许返回多个值的记分器函数。这对我来说很好,因为我只想要他们的意思。

这是我尝试过的:

    def sensitivity(y_true,y_pred):
        cm=confusion_matrix(y_true, y_pred)
        FP = cm.sum(axis=0) - np.diag(cm)  
        FN = cm.sum(axis=1) - np.diag(cm)
        TP = np.diag(cm)
        TN = cm.sum() - (FP + FN + TP)
        Sensitivity = TP/(TP+FN)    
        return np.mean(Sensitivity)

    def specificity(y_true,y_pred):
        cm=confusion_matrix(y_true, y_pred)
        FP = cm.sum(axis=0) - np.diag(cm)  
        FN = cm.sum(axis=1) - np.diag(cm)
        TP = np.diag(cm)
        TN = cm.sum() - (FP + FN + TP)
        Specificity = TN/(TN+FP)    
        return np.mean(Specificity)


    scoring = {'sensitivity' : make_scorer(sensitivity),
               'specificity' : make_scorer(specificity)}

    cross_validate(neural_network, data, y, cv=5,scoring=scoring)

但它仍然抛出同样的错误:

ValueError: Classification metrics can't handle a mix of multilabel-indicator and multiclass targets

我不知道这里有什么不工作。我只想要每个类别的灵敏度平均值和每个类别的特异性平均值,对于 5 折中的每一个。

我的方法有什么问题,还有更简单的方法吗?

2个回答

我认为这个错误来自confusion_matrix(),这里我们有三个“types_of_target”:多类,多标签指标,连续多输出。

比如np.array([1, 0, 2])是多类,它的one-hot-encodingnp.array([[0,1,0],[1,0,0],[0,0,1]])是multilabel-indicator,我们预测np.array([0.3,0.4,0.3],[0.7,0.2,0.1],[0.1,0.1,0.8])的是continuous-multioutput。

的输入confusion_matrix必须是“多类”类型。

我觉得你可以试试

confusion_matrix(y_true.argmax(axis=1),np.rint(y_pred).argmax(axis=1))

通过将 y_true 从 multilabel-indicator 转换为 multiclass,并将 y_pred 从 probs(continuous-multioutput) 转换为 one-hot(multilabel-indicator) 然后是 multiclass。

快速而肮脏的解决方案:(鉴于您不想“手工”完成)

只需使用分类报告,然后只需选择您想要的平均指标即可。