我是 CNN 的新手,正在研究/使用 MNIST 数据集。将数据拆分为训练集和测试集后,我需要使用“ImageDataGenerator”。我使用的代码与 Keras API 站点上的代码相同。
形状如下:
print(X_train.shape,
X_test.shape
,y_train.shape
,y_test.shape)
(31500, 784) (10500, 784) (31500,) (10500,)
但是我突然遇到了一个ValueError。下面是我的代码:
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True,
validation_split=0.2)
datagen.fit(X_train)
错误:
ValueError: Input to `.fit()` should have rank 4. Got array with shape: (31500, 784)
我该如何处理这个错误?有人可以帮忙吗?