如何将 decode_predictions() 用于非 Imagenet 模型..?

数据挖掘 神经网络 深度学习 喀拉斯 美国有线电视新闻网
2022-02-24 20:19:19

我知道 decode_predictions() 仅适用于 VGG16 等模型的 imagenet 数据集(1000 个类)。但请考虑我的情况。

我的场景

我使用了 vgg16 预训练模型,并添加了我自己的权重。

在此处输入图像描述

所以这被证明是一个非 Imagenet 模型。我已经提到了 classes=9,因为我只用 9 个类训练了我以前的模型。

在此处输入图像描述

所以现在要找到预测,我可以使用predict()方法,然后我的print(answer)会给出相应的类标签。但实际上我需要打印类名。是否有可能获得班级名称?如果是的话,谁能解释我怎么做。?

1个回答

在深度学习中,当你执行预测时,你会得到预测,它是一个包含 9 个概率的数组。所以首先对它们进行以下操作。

import numpy as np

prediction = model.prediction(test_data) 
# prediction will contain [[0.6, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]]

prediction = np.argmax(prediction[0])
# Now predition hold the index of the (0.6) i.e max probability value

现在,您应该有一个字典,其中包含 0、1、2、.. 8 的键和 classname1、classname2、...classname9 的值

这是您将类名作为输出的方式