如何使用 scikitplot 构建提升图

数据挖掘 机器学习 Python
2022-03-07 23:36:13

我正在尝试使用 scikitplot 库构建提升图。

我得到以下任何人都可以指导错误。

请找到有错误的代码

代码:

import scikitplot as skplt
actual = df['Actual']
predicted = df['Pred']

skplt.metrics.plot_lift_curve(act,pdt)

Error:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-38-f14ae9c16143> in <module>
      6 pdt = np.array(predicted)
      7 # plotLiftChart(actual,predicted)
----> 8 skplt.metrics.plot_lift_curve(act,pdt)
      9 # plt.show()

~\Anaconda3\lib\site-packages\scikitplot\metrics.py in plot_lift_curve(y_true, y_probas, title, ax, figsize, title_fontsize, text_fontsize)
   1192 
   1193     # Compute Cumulative Gain Curves
-> 1194     percentages, gains1 = cumulative_gain_curve(y_true, y_probas[:, 0],
   1195                                                 classes[0])
   1196     percentages, gains2 = cumulative_gain_curve(y_true, y_probas[:, 1],

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

请指导我在哪里做错了

1个回答

错误是没有阅读文档;)

scikitplot.metrics.plot_lift_curve(y_true, y_probas,...)

参数:

  • y_true(array-like, shape (n_samples)) – 真实(正确)目标值。
  • y_probas(array-like, shape (n_samples, n_classes)) – 分类器返回的每个类的预测概率。

第二个参数y_probas应该是每个类的预测概率的二维数组,而不是预测类的一维数组。

您还可以查看文档中提供的示例,其中包含:

y_probas = lr.predict_proba(X_test)