Tensorflow API:度量标准“tf.keras.metrics.TopKCategoricalAccuracy”有什么作用?

数据挖掘 机器学习 神经网络 喀拉斯 张量流 公制
2022-02-17 15:48:13

根据 API 文档,这个指标

“计算目标出现在前 K 个预测中的频率。”

但是为什么下面的代码会产生结果 1?0.95>0.9>0.8>0.1>0.05,0.95和0.8都导致预测为1,结果不应该是2吗?

m = tf.keras.metrics.TopKCategoricalAccuracy()
m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
print('Final result: ', m.result().numpy())  # Final result: 1.0
1个回答

的结果tf.keras.metrics.TopKCategoricalAccuracy()将介于 0 和 1 之间。参数 k 的默认值为 5。

结果为 1,因为对于两个样本,实际值都在前 5 个预测范围内。