使用 ImageDataGenerator() 和 flow_from_directory() 生成平衡批次

数据挖掘 分类 多类分类 阶级失衡
2022-02-25 17:50:18

嗨,我是 python 和深度学习的新手。我正在做一个多类分类。我的 3 类数据集是不平衡的,类分别占 50%、40% 和 20%。我正在尝试生成具有平衡类的小批量。class_weight用来生成平衡的批次,fit_generator()但我怀疑它是否真的有效,因为生成的批次train_datagen.flow_from_directory()不平衡。生成的批次的权重约为 [0.43, 0.38, 0.19]。我的代码如下:

train_datagen = ImageDataGenerator(rescale=1./255,
                                    featurewise_center=True,
                                    rotation_range=30,
                                    width_shift_range=0.3,
                                    height_shift_range=0.3,
                                    shear_range=0.2,
                                    zoom_range=0.2,
                                    horizontal_flip=True,
                                    fill_mode='constant')

#Training Set
train_set = train_datagen.flow_from_directory(
                                             directory=train_folder,
                                             target_size=input_shape[:2],
                                             batch_size=32,
                                             shuffle=True,
                                             class_mode='categorical')
#Validation Set
val_set = test_datagen.flow_from_directory(
                                            directory=val_folder,
                                            target_size=input_shape[:2],
                                            batch_size = 32,
                                            class_mode='categorical',
                                            shuffle=True)

call_backs = [EarlyStopping(monitor='val_loss', patience=train_patience),
             ModelCheckpoint(filepath=output_model, monitor='val_loss', save_best_only=True)]

class_weights = class_weight.compute_class_weight(
               'balanced',
                np.unique(train_set.classes), 
                train_set.classes)

history = model.fit_generator(
          train_set,
          steps_per_epoch=2000 // batch_size,
          epochs=300,
          validation_data=val_set,
           validation_steps=800 // batch_size,
           class_weight= class_weights,
           verbose=1,
           callbacks=call_backs)

此代码是否足以生成用于训练的平衡批次?我检查了从 train_set 生成的批次,它们在 [0.43, 0.38, 0.19] 左右。任何建议将不胜感激谢谢。

1个回答

class_weight不影响批次的组成。相反,它将权重应用于取决于类权重的损失函数。