- 我研究了自动编码器并尝试实现一个简单的。
- 我已经建立了一个带有一个隐藏层的模型。
- 我使用 MNIST 数字数据集运行它,并在自动编码器之前和之后绘制了数字。
- 我看到了一些使用大小为 32 或 64 的隐藏层的示例,我尝试过它并没有给出相同(或接近)源图像。
- 我尝试将隐藏层的大小更改为 784(与输入大小相同,只是为了测试模型),但得到了相同的结果。
我错过了什么?为什么网络上的示例显示出良好的结果,而当我测试它们时,我得到了不同的结果?
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Dense
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
# Build models
hiden_size = 784 # After It didn't work for 32 , I have tried 784 which didn't improve results
input_layer = Input(shape=(784,))
decoder_input_layer = Input(shape=(hiden_size,))
hidden_layer = Dense(hiden_size, activation="relu", name="hidden1")
autoencoder_output_layer = Dense(784, activation="sigmoid", name="output")
autoencoder = Sequential()
autoencoder.add(input_layer)
autoencoder.add(hidden_layer)
autoencoder.add(autoencoder_output_layer)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
encoder = Sequential()
encoder.add(input_layer)
encoder.add(hidden_layer)
decoder = Sequential()
decoder.add(decoder_input_layer)
decoder.add(autoencoder_output_layer)
#
# Prepare Input
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
#
# Fit & Predict
autoencoder.fit(x_train, x_train,
epochs=50,
batch_size=256,
validation_data=(x_test, x_test),
verbose=1)
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)
#
# Show results
n = 10 # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
# display original
ax = plt.subplot(2, n, i + 1)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# display reconstruction
ax = plt.subplot(2, n, i + 1 + n)
plt.imshow(decoded_imgs[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

