Pytorch:如何确保每批中都存在所有标签

数据挖掘 机器学习 Python 深度学习 火炬
2022-02-24 22:50:25

如何确保每批都有带有所有标签的样品?例如,考虑带有正面和负面标签的情感分析问题。

tokens = tokenizer.batch_encode_plus(text.tolist(),max_length = max_seq_len,pad_to_max_length=True,truncation=True, return_token_type_ids=False)    
seq = torch.tensor(tokens['input_ids'])
mask = torch.tensor(tokens['attention_mask'])
y = torch.tensor(labels.tolist())    
data = TensorDataset(seq, mask,y)
data_sampler = RandomSampler(data)
data_dataloader = DataLoader(data, sampler=data_sampler, batch_size=batch_size)

我想要像这样的批次

Batch-1 ['positive','positive','positive','negative']
Batch-2 ['negative','negative','positive','negative']

每个批次都包含所有标签。

1个回答

PyTorch 中有关于分层抽样选项。

但是,如果这不能满足您的需求,我的建议是要么使用scikit-learn适配 PyTorch 代码,要么阅读 scikit-learn 代码并将其适配到 PyTorch。