model.predict 不断预测相同的错误类别

数据挖掘 机器学习 张量流
2022-02-23 12:47:07

我有一个简单的 oxford102 (102 UK flowers) tensorflow 模型和来自这里的Android 应用程序

我想看看我是否可以在子项目 FlowerML 中的 python 代码中添加识别(model.predict),这是首先生成模型的 python 代码。

将以下内容添加到 FlowerML2.py 以识别单朵花:

# sequence of calls adapting image then calling predict:
# https://datascience.stackexchange.com/questions/31167/how-to-predict-an-image-using-saved-model
test_image = image.load_img('d:/ML/images/CallaLily-1.jpg', target_size=(224, 224))
print("test_image PIL size " + str(test_image.size))
test_image_arr = image.img_to_array(test_image)
print(test_image_arr.shape)
# double-check visually
image.save_img('d:/ml/images/WhatIsIt-1.jpg', test_image_arr)
test_image = np.expand_dims(test_image_arr, axis=0)
print(test_image.shape)
result2 = model.predict(tf.convert_to_tensor(test_image), batch_size=1, verbose=2)
print("\nnp.where result2: " + str(np.where(result2 == result2.max())) + " max = " + str(result2.max()) + "\n")

该模型在 Android 应用程序中正确识别了 Calla Lily,但是当我如上所述查询模型时,它一直告诉我错误的类。'Keeps' 指的是许多未在上面的修改 - 甚至是不同的花。这就像模型,在这种情况下(但不是在 Android 应用程序中),被困在其头层的输出 n=1 上。

由于它在 Android 中工作,我的假设是我没有正确调用 predict - 尽管究竟出了什么问题我不能说而且我没有什么经验,到目前为止,深入研究模型是如何工作的,这可能是你需要做的当它们不能正常工作时。

1个回答

弄清楚了。需要标准化我要预测的图像[1,1]而不是我认为的标准[0,1]. 似乎这取决于模型。我正在处理的项目使用迁移学习并导入此模型

虽然这个模型需要如何将输入标准化为[1,1].