在 TensorFlow 2.0 中进行预测/加载模型

数据挖掘 喀拉斯 张量流 预言
2022-02-27 23:14:31

我每天都使用 TensorFlow/Keras 对项目进行预测。一切正常,但我经常收到关于过渡到 TensorFlow 2.0 的警告,我想这周我最终会确保我的代码也能在新版本的库中工作。我在训练或保存模型期间没有遇到任何问题,但是在进行预测时,我收到了以下警告:

警告:tensorflow:您的输入数据已用完;中断训练。确保您的数据集或生成器至少可以生成 steps_per_epoch * epochs批次(在本例中为 10 个批次)。在构建数据集时,您可能需要使用 repeat() 函数。

事实证明它仍然按预期进行预测,但警告大大减慢了这个过程。我可以通过将steps=1参数传递给 来克服这个问题model.predict(),但这似乎是一种迂回的做事方式,在以前的 TensorFlow 版本中不需要这种方式。

我想知道我是否在这里遗漏了一些微不足道的东西。此外,TensorFlow 现在似乎无法弄清楚我是在做预测而不是训练,这在以前也不是问题。在相关的说明中,它可能也出现在之前的文档中,但我从未考虑过其中的batch_size论点model.predict()及其所服务的目的。

更新:

现在 Google Colab 将其默认版本更改为 TensorFlow 2,我决定再试一次。现在,代码仍然完全一样,但是当我尝试加载模型时出现错误:

警告:tensorflow:加载保存的优化器状态时出错。因此,您的模型从一个新初始化的优化器开始。

关于这个问题,TensorFlow github 页面上有一个未解决的问题:

https://github.com/tensorflow/tensorflow/issues/37968

解决后我会再次更新。

1个回答

我在 TF 2.1 中遇到了同样的问题。它与输入的形状/类型有关,即查询。就我而言,我解决了如下问题:(我的模型需要 3 个输入)

x = [[test_X[0][0]], [test_X[1][0]], [test_X[2][0]]]
MODEL.predict(x)

输出:

警告:tensorflow:您的输入数据已用完;中断训练。确保您的数据集或生成器至少可以生成 steps_per_epoch * epochs批次(在本例中为 7 个批次)。在构建数据集时,您可能需要使用 repeat() 函数。

数组([[2.053718]],dtype=float32)

解决方案:

x = [np.array([test_X[0][0]]), np.array([test_X[1][0]]), np.array([test_X[2][0]])]
MODEL.predict(x)

输出:

数组([[2.053718]],dtype=float32)