数据集映射函数错误:TypeError:“EagerPyFunc”操作的“输入”参数的预期列表,而不是张量

数据挖掘 张量流 地图减少
2022-02-26 00:07:34

我目前正在尝试编写一个脚本来创建一个 TFRecord 文件。

因此,我遵循官方 tensorflow 网站上的说明:https ://www.tensorflow.org/tutorials/load_data/tfrecord#writing_a_tfrecord_file

但是,当将 map 函数应用于 Dataset 的每个元素时,我会收到一个我不理解的错误。

这是我的代码(应该是可复制和可粘贴的):

import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset


def generate_random_img_data(n_count=10, patch_size=5):
    return np.random.randint(low=0, high=256, size=(n_count, patch_size, patch_size, 3))


def as_int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def serialize_one_image(img):
    features = {}
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            for k in range(img.shape[2]):
                features.update({str(i) + "_" + str(j) + "_" + str(k) : as_int64_feature(img[i,j,k]) })
    example_proto = tf.train.Example(features=tf.train.Features(feature=features))
    return example_proto.SerializeToString()


def tf_serialize_one_image(img):
    tf_string = tf.py_function(serialize_one_image, img, tf.string)
    return tf.reshape(tf_string,())


ds = Dataset.from_tensor_slices(generate_random_img_data())
ds_serialized = ds.map(tf_serialize_one_image) # <--- not working

运行此代码时出现错误:

TypeError: in user code:

    <ipython-input-116-ec81a7077c70>:25 tf_serialize_one_image  *
        tf_string = tf.py_function(serialize_one_image, img, tf.string)
    /Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py:455 eager_py_func  **
        func=func, inp=inp, Tout=Tout, eager=True, name=name)
    /Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py:341 _internal_py_func
        name=name)
    /Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/gen_script_ops.py:69 eager_py_func
        name=name)
    /Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:385 _apply_op_helper
        (input_name, op_type_name, values))

    TypeError: Expected list for 'input' argument to 'EagerPyFunc' Op, not Tensor("args_0:0", shape=(5, 5, 3), dtype=int64).

我在这里到底做错了什么?

1个回答

好吧,这是一个非常愚蠢的错误。

如错误消息所述,需要输入列表

所以替换这部分代码:

tf_string = tf.py_function(serialize_one_image, img, tf.string)

对此

tf_string = tf.py_function(serialize_one_image, [img], tf.string)

即将img对象包装到列表中解决了这个问题。

现在它按预期工作。

无论如何感谢您的阅读。