我目前正在尝试编写一个脚本来创建一个 TFRecord 文件。
因此,我遵循官方 tensorflow 网站上的说明:https ://www.tensorflow.org/tutorials/load_data/tfrecord#writing_a_tfrecord_file
但是,当将 map 函数应用于 Dataset 的每个元素时,我会收到一个我不理解的错误。
这是我的代码(应该是可复制和可粘贴的):
import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset
def generate_random_img_data(n_count=10, patch_size=5):
return np.random.randint(low=0, high=256, size=(n_count, patch_size, patch_size, 3))
def as_int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def serialize_one_image(img):
features = {}
for i in range(img.shape[0]):
for j in range(img.shape[1]):
for k in range(img.shape[2]):
features.update({str(i) + "_" + str(j) + "_" + str(k) : as_int64_feature(img[i,j,k]) })
example_proto = tf.train.Example(features=tf.train.Features(feature=features))
return example_proto.SerializeToString()
def tf_serialize_one_image(img):
tf_string = tf.py_function(serialize_one_image, img, tf.string)
return tf.reshape(tf_string,())
ds = Dataset.from_tensor_slices(generate_random_img_data())
ds_serialized = ds.map(tf_serialize_one_image) # <--- not working
运行此代码时出现错误:
TypeError: in user code:
<ipython-input-116-ec81a7077c70>:25 tf_serialize_one_image *
tf_string = tf.py_function(serialize_one_image, img, tf.string)
/Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py:455 eager_py_func **
func=func, inp=inp, Tout=Tout, eager=True, name=name)
/Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py:341 _internal_py_func
name=name)
/Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/ops/gen_script_ops.py:69 eager_py_func
name=name)
/Users/Tom/ML-Projects/vdst/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:385 _apply_op_helper
(input_name, op_type_name, values))
TypeError: Expected list for 'input' argument to 'EagerPyFunc' Op, not Tensor("args_0:0", shape=(5, 5, 3), dtype=int64).
我在这里到底做错了什么?