如何在 Keras 中通过 minibatch 检查点

数据挖掘 喀拉斯
2022-03-16 23:04:50

我知道我可以在 Keras 中使用 ModelCheckpoint来检查每个时期的模型(或每隔几个时期,这取决于我想要什么)。

我从 fit_generator 获取每个小批量的数据,评估每个小批量需要很长时间。我希望能够通过 minibatch 而不是 epoch 来检查点。我怎样才能在 Keras 中做到这一点?

1个回答

您必须为此编写自定义回调。步骤是:

  1. 子类 ModelCheckpoint ( https://github.com/keras-team/keras/blob/master/keras/callbacks.py ) 或者如果您不需要文件名模式等,则创建新的。

  2. 添加将在每个批次结束时调用的方法

class BatchModelCheckpoint(keras.callbacks.Callback):
     def on_batch_end(self, batch, logs=None):
        self.model.save(filepath, overwrite=True)