`.fit()` 的输入应该有 4 级。得到的数组形状为:(31500, 784)

数据挖掘 喀拉斯 张量流 美国有线电视新闻网
2022-03-04 02:14:47

我是 CNN 的新手,正在研究/使用 MNIST 数据集。将数据拆分为训练集和测试集后,我需要使用“ImageDataGenerator”。我使用的代码与 Keras API 站点上的代码相同。

形状如下:

print(X_train.shape,
X_test.shape
,y_train.shape
,y_test.shape)
(31500, 784) (10500, 784) (31500,) (10500,)

但是我突然遇到了一个ValueError。下面是我的代码:

import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2)

datagen.fit(X_train)

错误:

ValueError: Input to `.fit()` should have rank 4. Got array with shape: (31500, 784)

我该如何处理这个错误?有人可以帮忙吗?

1个回答

fit方法ImageDataGenerator期望输入具有四个维度 ( n_samples, height, width, n_channels)。您提供的数据只有两个维度,即n_samples, height*width*n_channels. 在使用方法之前尝试重塑数据fit,如下所示:

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2)

datagen.fit(X_train.reshape(31500, 28, 28, 1))