如何解决数组索引错误的索引过多

数据挖掘 喀拉斯 图像分类 索引
2022-02-25 22:15:58

我正在 Keras 中执行二进制分类并尝试绘制 ROC 曲线。当我尝试计算 fpr 和 tpr 指标时,出现“数组索引过多”错误。这是我的代码:

#declare the number of classes
num_classes=2
#predicted labels
y_pred = model.predict_generator(test_generator, nb_test_samples/batch_size, workers=1)
#true labels
Y_test=test_generator.classes
#print the predicted and true labels
print(y_pred)
print(Y_test)
'''y_pred float32 (624,2) array([[9.99e-01  2.59e-04],
                                 [9.97e-01  2.91e-03],...'''

'''Y_test int32 (624,) array([0,0,0,...,1,1,1],dtype=int32)'''

#reshape the predicted labels and convert type
y_pred = y_pred.argmax(axis=-1)
y_pred = y_pred.astype('int32')

#plot ROC curve
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(num_classes):
    fpr[i], tpr[i], _ = roc_curve(Y_test[:,i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
fig=plt.figure(figsize=(15,10), dpi=100)
ax = fig.add_subplot(1, 1, 1)
# Major ticks every 0.05, minor ticks every 0.05
major_ticks = np.arange(0.0, 1.0, 0.05)
minor_ticks = np.arange(0.0, 1.0, 0.05)
ax.set_xticks(major_ticks)
ax.set_xticks(minor_ticks, minor=True)
ax.set_yticks(major_ticks)
ax.set_yticks(minor_ticks, minor=True)
ax.grid(which='both')
lw = 1 
plt.plot(fpr[1], tpr[1], color='red',
         lw=lw, label='ROC curve (area = %0.4f)' % roc_auc[1])
plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristics')
plt.legend(loc="lower right")
plt.show()

y-pred 和 Y_test 的形状是:

y_pred float32 (624,2) 数组([[9.99e-01 2.59e-04], [9.97e-01 2.91e-03],...

Y_test int32 (624,) array([0,0,0,...,1,1,1],dtype=int32)

1个回答

您的代码在两个地方被破坏。

第一个是因为您从y_pred. 线

y_pred = y_pred.argmax(axis=-1)

将您的预测向量重塑(624,)为与您的类向量匹配。因此,当您稍后尝试对数组进行切片时,y_pred[:,i]它会吠叫,因为您不再有第二维。这也不是您真正想要的行为,因为该roc_curve函数对您的模型产生的确切类概率感兴趣!

第二个是出于同样的原因,试图索引一维 numpy 数组的第二维,但用于Y_test向量。

因此,如果您有兴趣通过将每个类视为正类来捕获这两个类的 TPR/FPR,则需要删除这些行

#reshape the predicted labels and convert type
y_pred = y_pred.argmax(axis=-1)
y_pred = y_pred.astype('int32')

并且您需要将 for 循环的第一行更改为:

fpr[i], tpr[i], _ = roc_curve(Y_test, y_pred[:, i])

希望这可以帮助