与 fit_generator 一起使用时,Keras 序列生成器会导致大量内存使用

数据挖掘 喀拉斯 张量流
2022-02-15 16:04:43

在使用非常大的数据集进行训练时,我想使用 fit_generator 来稳定内存使用情况。Q1:据我了解,生成器将批次放入队列中,该队列由 keras 的 fit_generator 函数获取,以在该批次上训练模型。在那个批次下雨之后,它应该从内存中释放那个批次吗?

Q2:在我的示例中,内存随着每个批次的增加而增加。没有看到内存清理。此外,内存确实会切换到交换内存。之后内存保持稳定,但交换内存会增加,直到程序崩溃。

Q3:我现在仅将 cpu 用于测试目的。当一切都必须在计算机的 ram 上运行并且一切都用 cpu 计算时,fit_generator 是否可以按预期工作?

这是我的代码:

class DataGenerator(Sequence):
    def __init__(self, batch_size, id_list, dim, shuffle=True):
        # batch_size: batch_size at each iteration
        self.batch_size = batch_size
        self.id_list = id_list
        self.dim = dim
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Denotes the number of batches per epoch
        return int(np.floor(len(self.id_list)/self.batch_size))

    def __getitem__(self, index):
        '''''
        MAIN function:
        - index tells the number of batches that have already been processed
        :return: X and y with shape [batch_size, window_size, features]        
        '''''
        X, y = self.__data_generation(self.id_list[index])
        return X, y


    @profile
    def on_epoch_end(self):
        if self.shuffle:
            random.shuffle(self.id_list)
        gc.collect()

    def __data_generation(self, fname):
        # 1. load data from specified path
        data = self.__load_data(filename=fname)
        data_x = np.squeeze(data[0, 0])
        data_y = np.squeeze(data[1, 0])
        return data_x, data_y

        # --------------------------------------- utility functions data loading

    def __load_data(self, filename):
        data = np.load(filename, allow_pickle=True)
        # load data from the filename
        return data

生成器比放入 fit_generator 函数

train_gen = DataGenerator(batch_size=1,
                             id_list=train_files,
                             dim=(256, 6),
                             shuffle=True)

model.fit_generator(generator=train_gen,
                       epochs=2,
                       verbose=1,
                       use_multiprocessing=True,
                       workers=5,
                       max_queue_size=10)

我想知道我的概念是否错误,或者我是否误解了 fit_generator 函数的使用。

1个回答

Q1:是的,理论上应该释放内存 Q2:这不是一个问题 Q3:是的,fit_generator 可以在 CPU 上正常工作。

根据我自己实施它们的经验,据我所知,您的数据生成器看起来不错,所以我只能建议现在如何继续找到根本原因,作为对实际问题的潜在补救措施。

我要尝试的第一件事是将多处理设置为 False 并将工作人员设置为 1。我自己在使用多处理功能时遇到了一些问题,虽然我不记得它是与内存相关还是与种族有关 -条件相关,我会先检查一切是否正常。我要看的第二件事是加载文件的替代方法,看看这些东西是否仍然会导致内存泄漏。