fit_generator 不接受自定义 keras 数据集生成器

数据挖掘 喀拉斯 张量流 数据集
2022-02-17 13:20:38

我在尝试在 Keras 中使用生成器时遇到了问题。我有 TensorFlow 2.1 和 Python 3.7。我正在使用与 TensorFlow 捆绑的 Keras 版本

我为我的生成器定义了一个类,它派生自tf.keras.models.Sequential

class RandomFramesFromPathsToVideos(tf.keras.models.Sequential):
    def __init__(self, x_set, y_set, number_of_videos_per_batch, ort_session, frames_per_video=25, type_of_frame='cropped_frame'):
        self.x, self.y = x_set, y_set
        self.batch_size = number_of_videos_per_batch * frames_per_video
        self.number_of_videos_per_batch = number_of_videos_per_batch
        self.frames_per_video = frames_per_video
        self.type_of_frame = type_of_frame
        self.ort_session = ort_session



    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.number_of_videos_per_batch)))

    def __getitem__(self, idx):

        bla, bla, bla...


        return (np.array(devol_x, dtype=np.float32), np.array(devol_y, dtype=np.float32))

它之所以有效,是因为我完全能够使用这样的精心制作的循环逐项训练我的模型:

ds_train = RandomFramesFromPathsToVideos(x_set = set_of_paths, y_set = set_of_labels, 
    number_of_videos_per_batch=4,
    ort_session = ort_session)

for i in range(0:len(ds_train)):
    (x,y)=ds_train.__getitem__(i)
    model.fit(x,y)

但是当我尝试使用应该使用的生成器时:

model.fit_generator(ds_train)

(我知道,现在在 TF2.1 中等效于model.fit(ds_train)),我收到错误消息:

ValueError: Failed to find data adapter that can handle input: <class 'functions.RandomFramesFromPathsToVideos'>, <class 'NoneType'>

到目前为止,我已经在 TensorFlow 的内部代码中跟踪了错误:

def select_data_adapter(x, y):
  """Selects a data adapter than can handle a given x and y."""
  adapter_cls = [cls for cls in ALL_ADAPTER_CLS if cls.can_handle(x, y)]
  if not adapter_cls:
    # TODO(scottzhu): This should be a less implementation-specific error.
    raise ValueError(
        "Failed to find data adapter that can handle "
        "input: {}, {}".format(
            _type_name(x), _type_name(y)))
  elif len(adapter_cls) > 1:
    raise RuntimeError(
        "Data adapters should be mutually exclusive for "
        "handling inputs. Found multiple adapters {} to handle "
        "input: {}, {}".format(
            adapter_cls, _type_name(x), _type_name(y)))
  return adapter_cls[0]

ALL_ADAPTER_CLS 的内容是:

result = {list: 7} [<class 'tensorflow.python.keras.engine.data_adapter.ListsOfScalarsDataAdapter'>, <class 'tensorflow.python.keras.engine.data_adapter.TensorLikeDataAdapter'>, <class 'tensorflow.python.keras.engine.data_adapter.GenericArrayLikeDataAdapter'>, <class 'tensor
 0 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.ListsOfScalarsDataAdapter'>
 1 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.TensorLikeDataAdapter'>
 2 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.GenericArrayLikeDataAdapter'>
 3 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.DatasetAdapter'>
 4 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.GeneratorDataAdapter'>
 5 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.KerasSequenceAdapter'>
 6 = {ABCMeta} <class 'tensorflow.python.keras.engine.data_adapter.CompositeTensorDataAdapter'>
 __len__ = {int} 7

我不明白为什么我的发电机不工作。有没有人知道我做错了什么?我的数据集是一大堆视频,无论如何都无法放入内存中。此外,我正在进行非标准转换,例如定位和裁剪出现在框架中的元素,因此我不能使用可能已经在 Keras 中实现的标准生成器。

非常感谢

1个回答

我按照教程上的说明进行操作,但似乎有一个错误,或者 TensorFlow 的内部代码在编写后发生了变化。我是这样定义我的班级的:

class DataGenerator(tf.keras.utils.Sequence):

后来我也试过这个:

class DataGenerator(tf.keras.models.Sequential):

(顺便说一句,这是完全不同的,是完全错误的)

检查一些适配器的代码,我看到了:

class KerasSequenceAdapter(GeneratorDataAdapter):
  """Adapter that handles `keras.utils.Sequence`."""

  @staticmethod
  def can_handle(x, y=None):
    return isinstance(x, data_utils.Sequence)

以及我的其他可疑适配器:

class GeneratorDataAdapter(DataAdapter):
  """Adapter that handles python generators and iterators."""

  @staticmethod
  def can_handle(x, y=None):
    return ((hasattr(x, "__next__") or hasattr(x, "next"))
            and hasattr(x, "__iter__")
            and not isinstance(x, data_utils.Sequence))

其中 data_utils 定义为:

from tensorflow.python.keras.utils import data_utils

因此,我将班级的定义更改为:

from tensorflow.python.keras.utils import data_utils
class RandomFramesFromPathsToVideos(data_utils.Sequence):

现在它可以工作了。

我有,所以,两个选择。使用 data_utils.Sequence 的子类就是其中之一。在这种情况下,我们谈论的是“Keras 生成器”(由 KerasSequenceAdapter 在 TensorFlow 中处理),您需要在其中定义两个方法:__len__ 和 __getitem__

另一个是构建一个不是从 data_utils.Sequence 派生的新类,并定义方法 __iter__ 和 __next__ (或简称 next)。在这种情况下,我们定义了一个标准的 Python 生成器,它将由 TensorFlow 中的 GeneratorDataAdapter 处理

请记住,Keras 生成器与 Python 生成器不同。起初,这让我很困惑。