My training accuracy is better than my test accuracy, so I suspected the model was overfitting and tried cross-validation. The score got even worse. Does my input data need further cleaning and better quality?
Please share your thoughts on what might be going wrong here.
My get_score function:
def get_score(model, X_train, X_test, y_train, y_test):
    model.fit(X_train, y_train.values.ravel())
    pred0 = model.predict(X_test)
    return accuracy_score(y_test, pred0)
The logic:
print('*TRAIN* Accuracy Score => '+str(accuracy_score(y_train, m.predict(X_train)))) # LinearSVC() used
print('*TEST* Accuracy Score => '+str(accuracy_score(y_test, pred))) # LinearSVC() used
print("... Cross Validation begins...")
y0 = pd.DataFrame(y)
y0.reset_index(drop=True, inplace=True)
print(X.shape)
print(y0.shape)
kf = KFold(n_splits=10)
e = []
for train_index, test_index in kf.split(X):
    X_train, X_test, y_train, y_test = X.iloc[train_index], X.iloc[test_index], y0.iloc[train_index], y0.iloc[test_index]
    print(train_index, test_index)
    e.append(get_score(LinearSVC(random_state=777), X_train, X_test, y_train, y_test))
print("Finally :: "+str(np.mean(e)))
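For reference, the manual fold loop above can be collapsed into scikit-learn's cross_val_score helper. This is a minimal sketch, not the original data: X and y0 here are synthetic stand-ins (a random DataFrame and Series), since the real dataset is not shown.

```python
import numpy as np
import pandas as pd
from sklearn.svm import LinearSVC
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-ins for X and y0 (the real data is not shown in the post)
rng = np.random.RandomState(0)
X = pd.DataFrame(rng.randn(200, 5))
y0 = pd.Series(rng.randint(0, 2, 200))

# Same 10-fold setup as the manual loop, one call instead of appending scores
kf = KFold(n_splits=10)
scores = cross_val_score(LinearSVC(random_state=777), X, y0.values.ravel(), cv=kf)
print(scores, scores.mean())
```

cross_val_score fits a fresh clone of the estimator on each training fold and scores it on the held-out fold, which is exactly what get_score does by hand.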
Output:
*TRAIN* Accuracy Score => 0.9451327433628318
*TEST* Accuracy Score => 0.6597345132743363
... Cross Validation begins...
(9040, 6458)
(9040, 1)
[ 904 905 906 ... 9037 9038 9039] [ 0 1 2 ... 901 902 903]
[ 0 1 2 ... 9037 9038 9039] [ 904 905 906 ... 1805 1806 1807]
[ 0 1 2 ... 9037 9038 9039] [1808 1809 1810 ... 2709 2710 2711]
[ 0 1 2 ... 9037 9038 9039] [2712 2713 2714 ... 3613 3614 3615]
[ 0 1 2 ... 9037 9038 9039] [3616 3617 3618 ... 4517 4518 4519]
[ 0 1 2 ... 9037 9038 9039] [4520 4521 4522 ... 5421 5422 5423]
[ 0 1 2 ... 9037 9038 9039] [5424 5425 5426 ... 6325 6326 6327]
[ 0 1 2 ... 9037 9038 9039] [6328 6329 6330 ... 7229 7230 7231]
[ 0 1 2 ... 9037 9038 9039] [7232 7233 7234 ... 8133 8134 8135]
[ 0 1 2 ... 8133 8134 8135] [8136 8137 8138 ... 9037 9038 9039]
Finally :: 0.32499999999999996
EDIT -1- Added the values of e:
[0.08075221238938053, 0.413716814159292, 0.05752212389380531, 0.15376106194690264, 0.14712389380530974, 0.4668141592920354, 0.6946902654867256, 0.7112831858407079, 0.33738938053097345, 0.18694690265486727]
EDIT -2- Added the shuffle=True argument to KFold()
Result:
[ 0 1 2 ... 9037 9038 9039] [ 4 5 10 ... 9007 9011 9024]
[ 0 1 2 ... 9037 9038 9039] [ 21 43 44 ... 9018 9035 9036]
[ 0 2 3 ... 9037 9038 9039] [ 1 20 60 ... 9023 9031 9034]
[ 0 1 2 ... 9036 9037 9038] [ 6 25 27 ... 9010 9025 9039]
[ 0 1 2 ... 9037 9038 9039] [ 15 16 28 ... 9029 9030 9033]
[ 0 1 2 ... 9037 9038 9039] [ 3 12 40 ... 9015 9017 9028]
[ 0 1 2 ... 9037 9038 9039] [ 7 8 23 ... 9013 9014 9027]
[ 0 1 3 ... 9035 9036 9039] [ 2 18 19 ... 9019 9037 9038]
[ 0 1 2 ... 9037 9038 9039] [ 24 37 39 ... 9012 9016 9026]
[ 1 2 3 ... 9037 9038 9039] [ 0 9 14 ... 9020 9021 9032]
[0.6504424778761062, 0.6736725663716814, 0.6969026548672567, 0.6692477876106194, 0.6769911504424779, 0.6382743362831859, 0.6692477876106194, 0.6648230088495575, 0.6648230088495575, 0.6814159292035398]
Finally :: 0.6685840707964601
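The jump from 0.32 to 0.67 is consistent with the rows being ordered by label: without shuffle=True, KFold takes contiguous chunks, so a test fold can consist almost entirely of a class that is rare or absent in its training folds (hence fold scores like 0.057 in EDIT -1-). A minimal sketch of that failure mode on synthetic, label-sorted data (not the original dataset):

```python
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.model_selection import KFold, cross_val_score

rng = np.random.RandomState(0)

# 10 classes, 100 samples each, sorted by label -- mimicking data that is
# stored grouped by category
y = np.repeat(np.arange(10), 100)
X = rng.randn(1000, 10)
X[np.arange(1000), y] += 3.0  # each class is separable along its own axis

model = LinearSVC(random_state=777, max_iter=5000)

# Without shuffling, each contiguous test chunk is exactly one class, and
# that class never appears in the training folds -> accuracy collapses
plain = cross_val_score(model, X, y, cv=KFold(n_splits=10))

# With shuffling, every training fold contains all classes
shuffled = cross_val_score(model, X, y, cv=KFold(n_splits=10, shuffle=True, random_state=0))

print(plain.mean())     # 0.0 here: the tested class is never in training
print(shuffled.mean())  # close to 1.0 on this separable toy data
```

StratifiedKFold (the default cv for classifiers in cross_val_score when an integer is passed) goes one step further and keeps the class proportions equal in every fold.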
