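I built a DecisionTreeClassifier with custom parameters to understand what changing them does and how the resulting model classifies instances of the iris dataset. My task now is to create a ROC curve that treats each class in turn as the positive class, which means the final plot needs 3 curves. To do this, I need to instantiate a OneVsRestClassifier and pass it the previous classifier as an argument, so that it automatically picks up the parameters I modified (for example, the class weights). Here is my current code: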
```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import numpy as np
import graphviz

iris = load_iris()
X = iris.data
y = iris.target
# Binarize the output
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]

# My DecisionTreeClassifier
clf = tree.DecisionTreeClassifier(criterion="entropy", random_state=300, min_samples_leaf=5,
                                  class_weight={0: 1, 1: 10, 2: 10})

# Hold out the last 10 samples of a random permutation as the test set
np.random.seed(0)
indices = np.random.permutation(len(iris.data))
indices_training = indices[:-10]
indices_test = indices[-10:]
iris_X_train = iris.data[indices_training]
iris_y_train = iris.target[indices_training]
iris_X_test = iris.data[indices_test]
iris_y_test = iris.target[indices_test]

# Training
clf = clf.fit(iris_X_train, iris_y_train)
# Test
predicted_y_test = clf.predict(iris_X_test)
print(confusion_matrix(iris_y_test, predicted_y_test))
print("Predictions:")
print(predicted_y_test)
print("True classes:")
print(iris_y_test)

# Learn to predict each class against the other
classifier = OneVsRestClassifier(clf)
# Train
classifier = classifier.fit(iris_X_train, iris_y_train)
# Test
y_score = classifier.predict_proba(iris_X_test)

# Compute ROC curve and ROC area for each class.
# roc_curve expects binary labels and a 1-D score, so binarize the test labels
# and take the matching column of both arrays for each class.
iris_y_test_bin = label_binarize(iris_y_test, classes=[0, 1, 2])
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(iris_y_test_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
```
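My problem is that I get the error:
Class label 2 not present.
on the line: classifier = classifier.fit(iris_X_train, iris_y_train)
This happens when I train the new classifier, and I don't understand why. I checked the iris dataset and it has three classes, so label 2 should correspond to virginica, right?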
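A likely explanation, offered as a hedged reading rather than a confirmed diagnosis: OneVsRestClassifier does not pass the original labels 0/1/2 to the wrapped tree. It fits one clone of the estimator per class on a binary target (1 for "this class", 0 for "the rest"), so every sub-estimator only ever sees the classes 0 and 1, and a class_weight entry for label 2 has no matching class. The minimal check below assumes a tree without class_weight so that the fit succeeds on any scikit-learn version. Whether a dict with an extra key raises "Class label 2 not present." or is silently accepted may depend on the installed version; if it is accepted, every binary sub-problem would weight its positive class by 10, which is probably not the intended weighting.

```python
# Minimal check of what OneVsRestClassifier hands to each wrapped tree.
# class_weight is deliberately omitted so the fit succeeds on any version.
from sklearn.datasets import load_iris
from sklearn.multiclass import OneVsRestClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

ovr = OneVsRestClassifier(
    DecisionTreeClassifier(criterion="entropy", random_state=300, min_samples_leaf=5)
)
ovr.fit(X, y)

# The wrapper itself knows about all three iris classes ...
print(ovr.classes_)

# ... but each fitted sub-estimator was trained on a 0/1 target, one per class.
for label, est in zip(ovr.classes_, ovr.estimators_):
    print(label, est.classes_)
```

On iris, the wrapper should report three classes while each of its three sub-estimators reports only 0 and 1.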
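If the goal is three one-vs-rest ROC curves that still reflect the weighting {0: 1, 1: 10, 2: 10}, one possible approach (an assumption, not the only fix) is to do the one-vs-rest split by hand and carry the weights as per-sample weights instead of a class_weight dict. In each binary problem the "rest" side mixes two iris classes, and a binary class_weight can only weight "this class" against "everything else". The sketch also swaps the 10-sample permutation split for a stratified train_test_split with test_size=0.3 (a hypothetical choice), so every class has positive examples in the test set and roc_curve is well defined. The name CLASS_WEIGHT is illustrative.

```python
# Sketch: manual one-vs-rest ROC curves for a weighted decision tree on iris.
# Assumptions: the per-class weights are applied as sample_weight, and a
# stratified 70/30 split replaces the original 10-sample test set.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

# Hypothetical name: the multiclass weights from the question, as a lookup table.
CLASS_WEIGHT = {0: 1, 1: 10, 2: 10}
# Each training sample keeps the weight of its original iris class.
sample_weight = np.array([CLASS_WEIGHT[label] for label in y_train], dtype=float)

fpr, tpr, roc_auc = {}, {}, {}
for i in np.unique(y):
    # Binary target: 1 for class i (the positive class), 0 for the other two.
    y_train_bin = (y_train == i).astype(int)
    y_test_bin = (y_test == i).astype(int)

    model = DecisionTreeClassifier(
        criterion="entropy", random_state=300, min_samples_leaf=5
    )
    model.fit(X_train, y_train_bin, sample_weight=sample_weight)

    # Probability of the positive class (column 1, since classes_ is [0, 1]).
    score = model.predict_proba(X_test)[:, 1]
    fpr[i], tpr[i], _ = roc_curve(y_test_bin, score)
    roc_auc[i] = auc(fpr[i], tpr[i])

for i, value in roc_auc.items():
    print(f"class {i}: AUC = {value:.3f}")
```

Each (fpr[i], tpr[i]) pair is one of the three curves to plot. Using sample_weight keeps the original per-class weighting inside the "rest" group, which a two-key class_weight on the binary problem cannot express.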