我在尝试在 Keras 中使用生成器时遇到了问题。我有 TensorFlow 2.1 和 Python 3.7。我正在使用与 TensorFlow 捆绑的 Keras 版本
我为我的生成器定义了一个类,它派生自tf.keras.models.Sequential:
class RandomFramesFromPathsToVideos(tf.keras.models.Sequential):
def __init__(self, x_set, y_set, number_of_videos_per_batch, ort_session, frames_per_video=25, type_of_frame='cropped_frame'):
self.x, self.y = x_set, y_set
self.batch_size = number_of_videos_per_batch * frames_per_video
self.number_of_videos_per_batch = number_of_videos_per_batch
self.frames_per_video = frames_per_video
self.type_of_frame = type_of_frame
self.ort_session = ort_session
def __len__(self):
return int(np.ceil(len(self.x) / float(self.number_of_videos_per_batch)))
def __getitem__(self, idx):
bla, bla, bla...
return (np.array(devol_x, dtype=np.float32), np.array(devol_y, dtype=np.float32))
它之所以有效,是因为我完全能够使用这样的精心制作的循环逐项训练我的模型:
ds_train = RandomFramesFromPathsToVideos(x_set = set_of_paths, y_set = set_of_labels,
number_of_videos_per_batch=4,
ort_session = ort_session)
for i in range(0:len(ds_train)):
(x,y)=ds_train.__getitem__(i)
model.fit(x,y)
但是当我尝试使用应该使用的生成器时:
model.fit_generator(ds_train)
(我知道,现在在 TF2.1 中等效于model.fit(ds_train)),我收到错误消息:
ValueError: Failed to find data adapter that can handle input: <class 'functions.RandomFramesFromPathsToVideos'>, <class 'NoneType'>
到目前为止,我已经在 TensorFlow 的内部代码中跟踪了错误:
def select_data_adapter(x, y):
"""Selects a data adapter than can handle a given x and y."""
adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
if not adapter_cls:
# TODO(scottzhu): This should be a less implementation-specific error.
raise ValueError(
"Failed to find data adapter that can handle "
"input: {}, {}".format(
_type_name(x), _type_name(y)))
elif len(adapter_cls) > 1:
raise RuntimeError(
"Data adapters should be mutually exclusive for "
"handling inputs. Found multiple adapters {} to handle "
"input: {}, {}".format(
adapter_cls, _type_name(x), _type_name(y)))
return adapter_cls[0]
ALL_ADAPTER_CLS 的内容是:
result = {list: 7} [<class 'tensorflow.python.keras.engine.data_adapter.ListsOfScalarsDataAdapter'>, <class 'tensorflow.python.keras.engine.data_adapter.TensorLikeDataAdapter'>, <class 'tensorflow.python.keras.engine.data_adapter.GenericArrayLikeDataAdapter'>, <class 'tensor
0 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.ListsOfScalarsDataAdapter'>
1 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.TensorLikeDataAdapter'>
2 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.GenericArrayLikeDataAdapter'>
3 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.DatasetAdapter'>
4 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.GeneratorDataAdapter'>
5 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.KerasSequenceAdapter'>
6 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.CompositeTensorDataAdapter'>
__len__ = {int} 7
我不明白为什么我的发电机不工作。有没有人知道我做错了什么?我的数据集是一大堆视频,无论如何都无法放入内存中。此外,我正在进行非标准转换,例如定位和裁剪出现在框架中的元素,因此我不能使用可能已经在 Keras 中实现的标准生成器。
非常感谢