要同时使用组和分层,您似乎应该编写自己的代码。请注意,您将不可避免地在训练和/或测试集中丢失样本(除非您很幸运)。
实现它的一种方法是:
- 按组进行拆分(您可以使用GroupKFold方法
sklearn)
- 检查训练/测试集中目标的分布。
- 随机删除训练或测试集中的目标以平衡分布。
注意:使用这种算法可能会导致组消失。在平衡训练/测试集时,您可能不希望随机删除目标。
这是一个示例代码
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupKFold
df = pd.DataFrame({
'person': ['a', 'b', 'c', 'aa', 'bb', 'cc', 'aaa', 'bbb', 'ccc'],
'group': [10, 10, 20, 20, 20, 20, 20, 30, 30],
'target': [1, 2, 2, 3, 2, 3, 1, 2, 3]
})
X = df['person']
y = df['target']
groups = df['group'].values
group_kfold = GroupKFold(n_splits=3)
group_kfold.get_n_splits(X, y, groups)
# First split by groups
for train_index, test_index in group_kfold.split(X, y, groups):
print("Groups split: TRAIN:", train_index, "TEST:", test_index)
y_train_grouped, y_test_grouped = y[train_index], y[test_index]
final_train_index = []
final_test_index = []
# Then balance the distributions for each target
for target in df['target'].unique():
target_train_index = y_train_grouped[y_train_grouped == target].index.tolist()
target_test_index = y_test_grouped[y_test_grouped == target].index.tolist()
n_training = len(target_train_index)
n_testing = len(target_test_index)
print("Target:" + str(target) + " - n_training:" + str(n_training) + " - n_testing:" + str(n_testing) +
" | target_train_index:" + str(target_train_index) + " - target_test_index:" + str(target_test_index))
# Shuffle to remove randomly
np.random.shuffle(target_train_index)
np.random.shuffle(target_test_index)
# Check if we need to remove samples from training or testing set
if n_training > n_testing:
while n_training > n_testing:
target_train_index.pop(0)
n_training = len(target_train_index)
if n_training < n_testing:
while n_training < n_testing:
target_test_index.pop(0)
n_testing = len(target_test_index)
# Append new indexes to global train/test indexes
final_train_index.append(target_train_index)
final_test_index.append(target_test_index)
# Flatten for readability
final_train_index = [item for sublist in final_train_index for item in sublist]
final_test_index = [item for sublist in final_test_index for item in sublist]
print("FINAL split: TRAIN:", final_train_index," TEST:", final_test_index, "\n")
编辑
使用分层交叉验证似乎不是强制性的(请参阅下面的链接),因此您可能会重新考虑使用它。
您可能会发现此链接很有用:
关于过采样/欠采样,我认为如果没有更多关于数据分布的详细信息以及您的类的不平衡程度,很难回答。