TensorFlow 文本预测如何在没有 softmax 激活的情况下进行预测

数据挖掘 张量流 rnn
2022-03-05 22:02:02

在这里的 Colab 笔记本: RNN text generation in 中def generate_text(),有 predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()

我在tf.random.categorical这里查看: stackoverflow 答案

并且有点理解它是如何工作的。

我试图调试/弄清楚它对打印语句的作用:

for i in range(num_generate):
      predictions = model(input_eval)
      # remove the batch dimension
      predictions = tf.squeeze(predictions, 0)

      # using a categorical distribution to predict the word returned by the model
      predictions = predictions / temperature
      predicted_id1 = tf.random.categorical(predictions, num_samples=1)
      predicted_id2 = tf.random.categorical(predictions, num_samples=1)[-1,0]
      predicted_id3 = tf.random.categorical(predictions, num_samples=1)[0,0]
      predicted_id4 = tf.random.categorical(predictions, num_samples=1)[-2,0]
      predicted_id5 = tf.random.categorical(predictions, num_samples=1)[1,0]
      predicted_id6 = tf.random.categorical(predictions, num_samples=1)[2,0]
      predicted_id7 = tf.random.categorical(predictions, num_samples=1)[3,0]
      #predicted_id8 = tf.random.categorical(predictions, num_samples=1)[4,0]

      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy() #index 0 for  

      print("predicted_id1", predicted_id1)
      print("predicted_id2", predicted_id2)
      print("predicted_id3", predicted_id3)
      print("predicted_id4", predicted_id4)
      print("predicted_id5", predicted_id5)
      print("predicted_id6", predicted_id6)
      print("predicted_id7", predicted_id7)

      print("predicted_id", predicted_id)

这是输出:

predicted_id1 tf.Tensor(
[[19]
 [33]
 [ 3]
 [35]
 [ 4]
 [64]
 [22]], shape=(7, 1), dtype=int64)
predicted_id2 tf.Tensor(35, shape=(), dtype=int64)
predicted_id3 tf.Tensor(19, shape=(), dtype=int64)
predicted_id4 tf.Tensor(3, shape=(), dtype=int64)
predicted_id5 tf.Tensor(38, shape=(), dtype=int64)
predicted_id6 tf.Tensor(26, shape=(), dtype=int64)
predicted_id7 tf.Tensor(36, shape=(), dtype=int64)
predicted_id 29

所以看起来存在某种分布,并且一些指数是从该分布中挑选出来的,但这种情况下的实际预测29并没有出现在分布中,所以我很困惑。分布中的元素不是文本中字符的整数 ID 吗?我在 Udacity DLND 中学到的一种方法是将概率分配给预测的下一个角色并选择,argmax所以请随时启发我。

1个回答

logits 有根据词汇的 len 和模型训练的分布。因此,您可以使用 np.argmax(logits) 来获得预测,但通常对于生成脚本的应用程序更有趣的是要考虑一个偶然因素,在这种情况下是用于获取的函数“random_categorical”根据概率的值。值 29 出现在分布中,您没有看到,因为您多次执行该函数并且每次执行,都会出现“随机值”。如果您想要确切的值,请使用 np.argmax。