我正在使用以下 scikit-learn 设置训练线性模型:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_score
# [...] feature_matrix (n_samples x n_features), y (labels) and indices
# are presumably built in the omitted code above — TODO confirm.
random_state = 786543
# NOTE(review): an orphaned fragment "max_iter=5, tol=None)" appeared here in
# the original — a leftover from an earlier, edited LinearSVC(...) call and a
# syntax error on its own. Removed; the classifier actually fitted is below.
clf = LinearSVC(random_state=random_state, dual=True, C=1.5)
# Split features, labels and the original row indices in one call so all
# three arrays stay aligned across train/test.
X_train, X_test, y_train, y_test, i_train, i_test = train_test_split(
    feature_matrix, y, indices, test_size=0.33, random_state=random_state)
clf.fit(X_train, y_train.values)  # .values: y_train is presumably a pandas Series — verify
predicted_train = clf.predict(X_train)
predicted_test = clf.predict(X_test)
print('Train Accuracy: ' + str(np.mean(y_train == predicted_train)))
print('Test Accuracy: ' + str(np.mean(y_test == predicted_test)))
print('Test F1 micro: ' + str(f1_score(y_test, predicted_test, average='micro')))
print('Test F1 macro: ' + str(f1_score(y_test, predicted_test, average='macro')))
print('Test F1 weighted: ' + str(f1_score(y_test, predicted_test, average='weighted')))
Train Accuracy: 0.985129495926343
Test Accuracy: 0.9601936525013448
Test F1 micro: 0.9601936525013448
Test F1 macro: 0.9000889214688401
Test F1 weighted: 0.9590331562500389
但现在我运行:
# cross_val_score clones clf and re-fits the clone on each of the 5 folds over
# the FULL dataset; each score is f1_macro on that fold's held-out portion, so
# these numbers come from different train/test partitions than the single
# split evaluated above.
scores = cross_val_score(clf, feature_matrix, y, cv=5, scoring='f1_macro')
print(scores)
array([0.65860981, 0.84306338, 0.82113645, 0.83414211, 0.64665942])
如何解释这种差异?我使用不同的随机状态对此进行了测试。
需要考虑的几点:
- 我有多个类(但每个样本只有一个标签)
- 数据集是倾斜的(所以有些类有很多样本,有些类很少)
- 我有 45066 个样本,5222 个特征,259 个类
每类样本数为:
sorted(list(np.unique(y, return_counts=True)[1]))
[1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 7, 7, 8, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 21, 21, 21, 22, 22, 22, 22, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 27, 27, 29, 29, 29, 29, 30, 30, 30, 30, 32, 32, 32, 34, 34, 35, 35, 35, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 38, 38, 38, 38, 40, 40, 40, 41, 41, 45, 45, 45, 46, 46, 46, 46, 47, 47, 47, 48, 49, 50, 50, 52, 55, 56, 59, 59, 60, 60, 61, 61, 61, 65, 65, 67, 67, 69, 72, 73, 74, 75, 77, 77, 79, 80, 84, 85, 87, 93, 96, 97, 97, 103, 110, 112, 117, 123, 130, 139, 139, 141, 143, 146, 146, 147, 147, 150, 159, 161, 169, 170, 177, 180, 180, 189, 191, 196, 198, 199, 201, 202, 203, 203, 208, 211, 230, 236, 249, 255, 264, 268, 269, 300, 332, 347, 356, 358, 364, 388, 433, 469, 476, 484, 548, 652, 698, 723, 748, 753, 807, 815, 1013, 1200, 1222, 1243, 1274, 1447, 1643, 1741, 2900, 3909, 4627]