我知道我可以在 Keras 中使用 ModelCheckpoint来检查每个时期的模型(或每隔几个时期,这取决于我想要什么)。
我从 fit_generator 获取每个小批量的数据,评估每个小批量需要很长时间。我希望能够通过 minibatch 而不是 epoch 来检查点。我怎样才能在 Keras 中做到这一点?
我知道我可以在 Keras 中使用 ModelCheckpoint来检查每个时期的模型(或每隔几个时期,这取决于我想要什么)。
我从 fit_generator 获取每个小批量的数据,评估每个小批量需要很长时间。我希望能够通过 minibatch 而不是 epoch 来检查点。我怎样才能在 Keras 中做到这一点?
您必须为此编写自定义回调。步骤是:
子类 ModelCheckpoint ( https://github.com/keras-team/keras/blob/master/keras/callbacks.py ) 或者如果您不需要文件名模式等,则创建新的。
添加将在每个批次结束时调用的方法
class BatchModelCheckpoint(keras.callbacks.Callback): def on_batch_end(self, batch, logs=None): self.model.save(filepath, overwrite=True)