from tensorflow.keras.datasets import mnist

# Load the MNIST images; the labels are discarded because the autoencoder
# is trained to reproduce images, not to classify them.
(X_train, _), (X_test, _) = mnist.load_data()
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)

# Scale X to range between 0 and 1 (stored as float32).
X_train, X_test = (
    images.astype('float32') / 255. for images in (X_train, X_test)
)

# Reshape X to (n_samples, height, width, n_channels): the convolutional
# layers expect an explicit trailing channel axis.
X_train, X_test = (
    images.reshape(-1, 28, 28, 1) for images in (X_train, X_test)
)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)
import numpy as np

# Strength of the corruption applied to the clean images.
noise_factor = 0.5

# Corrupt the training images with Gaussian noise, then clip back into the
# valid [0, 1] pixel range. The training noise is drawn before the test
# noise so the sequence of random draws is unchanged.
noise = noise_factor * np.random.normal(loc=0.5, scale=0.5,
                                        size=X_train.shape)
X_train_noisy = np.clip(X_train + noise, 0., 1.)

# Same corruption, drawn independently, for the test images.
noise = noise_factor * np.random.normal(loc=0.5, scale=0.5,
                                        size=X_test.shape)
X_test_noisy = np.clip(X_test + noise, 0., 1.)
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Activation
from tensorflow.keras.models import Model
import warnings
warnings.filterwarnings('ignore')

# Single input tensor shared by the encoder and the full autoencoder.
input_img = Input(shape=(28, 28, 1), name='Encoder_input')

# Encoder: two conv + 2x2 max-pooling stages (28x28 -> 14x14 -> 7x7),
# followed by a single-filter sigmoid convolution that yields the code.
Enc = Conv2D(filters=16, kernel_size=(3, 3), padding='same',
             activation='relu', name='Enc_conv2d_1')(input_img)
Enc = MaxPooling2D(pool_size=(2, 2), padding='same',
                   name='Enc_max_pooling2d_1')(Enc)
Enc = Conv2D(filters=8, kernel_size=(3, 3), padding='same',
             activation='relu', name='Enc_conv2d_2')(Enc)
Enc = MaxPooling2D(pool_size=(2, 2), padding='same',
                   name='Enc_max_pooling2d_2')(Enc)
Encoded = Conv2D(filters=1, kernel_size=(3, 3), padding='same',
                 activation='sigmoid', name='Enc_conv2d_3')(Enc)

# Instantiate the Encoder Model (input image -> encoded map).
encoder = Model(inputs=input_img, outputs=Encoded)
encoder.summary()

# Decoder: mirrors the encoder, using 2x2 up-sampling to return to the
# input resolution; the final sigmoid keeps outputs in the (0, 1) range.
Dec = Conv2D(filters=8, kernel_size=(3, 3), padding='same',
             activation='relu', name='Dec_conv2d_1')(Encoded)
Dec = UpSampling2D(size=(2, 2), name='Dec_upsampling2d_1')(Dec)
Dec = Conv2D(filters=16, kernel_size=(3, 3), padding='same',
             activation='relu', name='Dec_conv2d_2')(Dec)
Dec = UpSampling2D(size=(2, 2), name='Dec_upsampling2d_2')(Dec)
decoded = Conv2D(filters=1, kernel_size=(3, 3), padding='same',
                 activation='sigmoid', name='Dec_conv2d_3')(Dec)

# Instantiate the Autoencoder Model (input image -> reconstruction). It
# shares its encoder layers with `encoder`, so training it trains both.
autoencoder = Model(inputs=input_img, outputs=decoded)
autoencoder.summary()
# Compile the autoencoder: pixel-wise mean squared error, Adam optimizer.
autoencoder.compile(optimizer='adam', loss='mse')

# Train the autoencoder to map noisy images back to their clean originals.
# Only the per-epoch metrics dictionary is kept in `history`.
history = autoencoder.fit(
    x=X_train_noisy,
    y=X_train,
    batch_size=20,
    epochs=10,
    shuffle=True,
    validation_split=1/6,
).history
import matplotlib.pyplot as plt

# `%matplotlib inline` is an IPython magic and a syntax error in plain
# Python. Request the inline backend only when running under IPython, so
# this code works both in a notebook and as a script.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass  # not running under IPython; keep the default backend

# Per-epoch losses recorded during training.
tr_loss = history['loss']
val_loss = history['val_loss']
epochs = range(1, len(tr_loss) + 1)

# Training (red) versus validation (blue) loss.
fig = plt.figure(figsize=(8, 4))
plt.plot(epochs, tr_loss, 'r')
plt.plot(epochs, val_loss, 'b')
plt.title('Model loss')
plt.ylabel('MSE')
plt.xlabel('Epoch')
plt.legend(['Training', 'Validation'], loc='upper right')
# Adjust the layout only after the axes, labels and legend exist;
# before that there is nothing for tight_layout to arrange.
fig.tight_layout()
plt.show()
import os

# Import from tensorflow.keras like the rest of the file; mixing in the
# standalone `keras` package can pull in a different Keras installation.
from tensorflow.keras.models import model_from_json

# Make sure the output directory exists; otherwise the open() calls
# below raise FileNotFoundError on a fresh checkout.
os.makedirs("logs", exist_ok=True)

# Save the Encoder: architecture as JSON, weights as HDF5.
model_json = encoder.to_json()
with open("logs/Encoder_model.json", "w") as json_file:
    json_file.write(model_json)
encoder.save_weights("logs/Encoder_weights.h5")

# Save the Autoencoder the same way.
model_json = autoencoder.to_json()
with open("logs/Autoencoder_model.json", "w") as json_file:
    json_file.write(model_json)
autoencoder.save_weights("logs/Autoencoder_weights.h5")
from tensorflow.keras.models import model_from_json
import warnings
warnings.filterwarnings('ignore')

# Rebuild the encoder from its saved JSON architecture, then restore
# its trained weights.
with open('logs/Encoder_model.json', 'r') as f:
    encoder_json = f.read()
    Myencoder = model_from_json(encoder_json)
Myencoder.load_weights("logs/Encoder_weights.h5")

# Rebuild and restore the full autoencoder in the same way.
with open('logs/Autoencoder_model.json', 'r') as f:
    autoencoder_json = f.read()
    MyAutoencoder = model_from_json(autoencoder_json)
MyAutoencoder.load_weights("logs/Autoencoder_weights.h5")
# Pick randomly some images from test set
num_images = 10
random_test_images = np.random.randint(X_test.shape[0], size=num_images)

# Predict the Encoder and the Autoencoder outputs from the noisy test images
encoded_imgs = Myencoder.predict(X_test_noisy)
decoded_imgs = MyAutoencoder.predict(X_test_noisy)

# One grid row per processing stage: (row title, image stack, 2-D shape).
panel_rows = [
    ('Original Images', X_test, (28, 28)),
    ('Noised Images', X_test_noisy, (28, 28)),
    ('Encoded Images', encoded_imgs, (7, 7)),
    ('Denoised Images', decoded_imgs, (28, 28)),
]

# Keep a handle on THIS figure so the layout call below targets it and
# not a figure created earlier in the script.
fig = plt.figure(figsize=(20, 10))
# Select the gray colormap once, up front; every imshow below uses it.
plt.gray()
for i, image_idx in enumerate(random_test_images):
    for row_idx, (row_title, images, image_shape) in enumerate(panel_rows):
        ax = plt.subplot(len(panel_rows), num_images,
                         row_idx * num_images + i + 1)
        plt.imshow(images[image_idx].reshape(image_shape))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Label each row once, above its middle column.
        if i == num_images // 2:
            ax.set_title(row_title)
# Adjust spacing only after all subplots and titles exist.
fig.tight_layout()
plt.show()