import tensorflow as tf
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

# Load the trained VQ-VAE Keras model from the SavedModel directory.
model = tf.keras.models.load_model('VQ-VAE-Model')
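# If the checkpoint were stored in H5 format rather than as a SavedModel, this
# load would likely also need custom_objects={'VectorQuantizer': VectorQuantizer}
# (an assumption about how the model was exported; the SavedModel format
# bundles the layer's computation graph, so no custom objects are needed here).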

class VectorQuantizer(tf.keras.layers.Layer):
    """Discretizes encoder outputs by snapping them to a learned codebook."""

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # Weight on the commitment loss term (beta in the VQ-VAE paper).
        self.beta = beta

        # Codebook: num_embeddings vectors, each of length embedding_dim.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )
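    # Note the (embedding_dim, num_embeddings) layout: each column is one
    # codebook entry, which is why call() matmuls with transpose_b=True.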

    def call(self, x):
        # Keep the original shape so the quantized output can be restored to
        # it; flatten everything else so each spatial position becomes one
        # embedding_dim-dimensional vector.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantize: look up the nearest codebook entry for each vector.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
        quantized = tf.reshape(quantized, input_shape)

        # Commitment loss pulls the encoder output toward its chosen code;
        # codebook loss moves the code toward the encoder output.
        commitment_loss = self.beta * tf.reduce_mean(
            (tf.stop_gradient(quantized) - x) ** 2
        )
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: the forward pass uses the quantized
        # values, while gradients flow back to x as if quantization were
        # the identity function.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized
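    # Minimal usage sketch (the 7x7x16 latent grid is assumed from this demo's
    # encoder; other models would differ):
    #   vq = VectorQuantizer(num_embeddings=64, embedding_dim=16)
    #   quantized = vq(encoder_output)  # (batch, 7, 7, 16) in and out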

    def get_code_indices(self, flattened_inputs):
        # Squared L2 distance to every codebook entry, expanded as
        # ||x||^2 + ||e||^2 - 2<x, e> to avoid materializing all differences.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # Index of the nearest codebook vector for each input row.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

# Rebuild the quantizer (64 codes of dimension 16) with the trained codebook,
# and grab the encoder sub-model for visualizing discrete latents.
vq_object = VectorQuantizer(64, 16)
embs = np.load('embeddings.npy')
vq_object.embeddings = embs
encoder = model.layers[1]
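# Assigning the NumPy array wholesale replaces the tf.Variable created in
# __init__; the TensorFlow ops in get_code_indices accept it unchanged, which
# is all this inference-only demo needs.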

# MNIST test set, shifted to [-0.5, 0.5] to match the training preprocessing.
_, (x_test, _) = tf.keras.datasets.mnist.load_data()
x_test = np.expand_dims(x_test, -1)
x_test_scaled = (x_test / 255.0) - 0.5

def make_subplot_reconstruction(original, reconstructed):
    # Three rows of (original, reconstruction) pairs, shifted back from
    # [-0.5, 0.5] to [0, 1] for display.
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].imshow(original[row_idx].squeeze() + 0.5)
        axs[row_idx, 0].axis('off')
        axs[row_idx, 1].imshow(reconstructed[row_idx].squeeze() + 0.5)
        axs[row_idx, 1].axis('off')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Reconstruction")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def make_subplot_latent(original, reconstructed):
    # Three rows: original digit next to its 7x7 grid of discrete code indices.
    fig, axs = plt.subplots(3, 2)
    for row_idx in range(3):
        axs[row_idx, 0].matshow(original[row_idx].squeeze())
        axs[row_idx, 0].axis('off')

        axs[row_idx, 1].matshow(reconstructed[row_idx].squeeze())
        axs[row_idx, 1].axis('off')
        # Overlay each code index on its cell; Axes.text takes (x, y), so the
        # column index j must come first to match matshow's orientation.
        for i in range(7):
            for j in range(7):
                c = reconstructed[row_idx][i, j]
                axs[row_idx, 1].text(j, i, str(c), va='center', ha='center')

    axs[0, 0].title.set_text("Original")
    axs[0, 1].title.set_text("Discrete Latent Representation")
    plt.tight_layout()
    fig.set_size_inches(10, 10.5)
    return fig

def plot_sample(mode):
    # Pick three random test digits and plot either full reconstructions or
    # their discrete latent codes, depending on the radio selection.
    sample = np.random.choice(x_test.shape[0], 3)
    test_images = x_test_scaled[sample]
    if mode == 'Reconstruction':
        reconstructions_test = model.predict(test_images)
        return make_subplot_reconstruction(test_images, reconstructions_test)

    # Encode, flatten to (num_vectors, embedding_dim), quantize, then reshape
    # the code indices back onto the spatial grid. The .numpy() call works
    # because TF2 executes get_code_indices eagerly here.
    encoded_out = encoder.predict(test_images)
    encoded = encoded_out.reshape(-1, encoded_out.shape[-1])
    quant = vq_object.get_code_indices(encoded)
    quant = quant.numpy().reshape(encoded_out.shape[:-1])

    return make_subplot_latent(test_images, quant)

demo = gr.Blocks()

with demo:
    gr.Markdown("# Vector-Quantized Variational Autoencoders (VQ-VAE)")
    gr.Markdown("""This space demonstrates the use of VQ-VAEs. Similar to traditional VAEs, VQ-VAEs try to learn a useful latent representation.
However, a VQ-VAE's latent space is **discrete** rather than continuous. Below, we can view how well this model compresses and reconstructs MNIST digits and, more importantly, inspect the
discretized latent representation. These discrete representations can then be paired with a network like PixelCNN to generate novel images.

VQ-VAEs are one of the tools used by DALL-E and are among the only models that perform on par with VAEs while using a discrete latent space.
For more information, check out this [paper](https://arxiv.org/abs/1711.00937) and
[example](https://keras.io/examples/generative/vq_vae/).<br>
Full credits for this example go to [Sayak Paul](https://twitter.com/RisingSayak).<br>
The model card can be found [here](https://huggingface.co/brendenc/VQ-VAE).<br>
Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195)""")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                radio = gr.Radio(choices=['Reconstruction', 'Discrete Latent Representation'])
            with gr.Row():
                button = gr.Button('Run')
        with gr.Column():
            out = gr.Plot()

    button.click(plot_sample, radio, out)

demo.launch()
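# launch() serves the app on a local URL; passing share=True would also create
# a temporary public Gradio link, if desired.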