Variational Autoencoder (DRAFT)

The Variational Autoencoder (VAE) is the successor to the regular autoencoder (AE). As you probably already know, a regular AE is a neural network that compresses its input into a latent space and then reconstructs the original input from that latent representation. In drawings, you will often find something like this:

Standard Autoencoder

This is simple stuff, and you can easily create such a network. Here is a quick example of what an AE looks like in code:

import equinox as eqx
import jax
from jaxtyping import Array, PRNGKeyArray


class Autoencoder(eqx.Module):
    encoder: eqx.nn.Sequential
    decoder: eqx.nn.Sequential

    def __init__(self, dim: int, hidden_dim: int, z_dim: int, key: PRNGKeyArray):
        key, *subkeys = jax.random.split(key, 10)
        # The encoder compresses the input down to the latent dimension.
        self.encoder = eqx.nn.Sequential(
            [
                eqx.nn.Linear(dim, hidden_dim, key=subkeys[0]),
                eqx.nn.Lambda(fn=jax.nn.relu),
                eqx.nn.Linear(hidden_dim, z_dim, key=subkeys[1]),
            ]
        )

        # The decoder maps the latent vector back to the input dimension.
        self.decoder = eqx.nn.Sequential(
            [
                eqx.nn.Linear(z_dim, hidden_dim, key=subkeys[2]),
                eqx.nn.Lambda(fn=jax.nn.relu),
                eqx.nn.Linear(hidden_dim, dim, key=subkeys[3]),
            ]
        )

    def __call__(self, x: Array) -> tuple[Array, Array]:
        z = self.encoder(x)
        o = self.decoder(z)
        return o, z
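
To make this concrete, here is a minimal usage sketch. The dimensions are made up for illustration, and the "image" is just random noise:

# Sketch: instantiating and calling the Autoencoder (hyperparameters are made up).
key = jax.random.key(0)
model_key, data_key = jax.random.split(key)

ae = Autoencoder(dim=784, hidden_dim=128, z_dim=8, key=model_key)

x = jax.random.normal(data_key, (784,))   # one fake flattened "image"
x_rec, z = ae(x)                          # reconstruction and latent code
print(x_rec.shape, z.shape)               # (784,) (8,)

# For a batch, Equinox modules are typically mapped with eqx.filter_vmap:
batch = jax.random.normal(data_key, (32, 784))
x_rec_batch, z_batch = eqx.filter_vmap(ae)(batch)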

Very simple stuff. Ok, but let's say you have trained your AE on MNIST and you get some nice reconstructions like these:

Reconstructions

One thing you might be thinking is this: if I give my AE an image of a 1 and get some vector $z_1$ back, and then I encode an image of a 2 and get another vector $z_2$ back, what does the midpoint between $z_1$ and $z_2$ look like? After all, I can put either vector $z$ into my decoder and get a nice reconstruction of my original input back. Does this mean that the midpoint between the encoded vectors for the images 1 and 2 is an image that looks kind of like both 1 and 2? If it weren't digits but faces, could I give the AE two faces, take the midpoint and put that through the decoder to get a completely new face back?

The unfortunate truth is: no!

But if we could somehow tidy up the latent space, then yes, we could generate new and authentic-looking images. And the way we tidy up the latent space is with a VAE. VAEs solve this by forcing the latent space to follow a specific distribution (a Gaussian), which creates a smooth, organized latent space where interpolation works!
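
To make the interpolation idea concrete, here is a small sketch of what that experiment would look like. It assumes a trained model like the `vae_model` built later in this post and two flattened MNIST images `x1` and `x2`:

# Sketch: interpolate between the latent codes of two inputs x1 and x2.
# Assumes a trained vae_model (defined later in this post) and that x1, x2
# are flattened MNIST images of shape (784,).
import jax.numpy as jnp

mu1, _ = vae_model.encoder(x1)   # we only use the means here
mu2, _ = vae_model.encoder(x2)

# Walk from mu1 to mu2 in 10 steps and decode every intermediate point.
alphas = jnp.linspace(0.0, 1.0, 10)
interpolations = jnp.stack(
    [vae_model.decoder((1 - a) * mu1 + a * mu2) for a in alphas]
)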

In a VAE, we have two spaces: the data space $p(x)$ and the latent space $p(z)$, and we don't really have access to either of them (we only have a bunch of data points sampled from $p(x)$, but that's about it). The VAE makes a design decision and says that $p(z)$ is a standard normal distribution, and this will be very useful later in the loss function derivation.

Between these, we have two mappings (both of which we model as normal distributions) that map one space onto the other:

$$
\begin{align*}
& p(x|z) \qquad \text{(kind of decoder)} \\
& p(z|x) \qquad \text{(kind of encoder)}
\end{align*}
$$

These mappings are like our encoder and decoder: $p(x|z)$ generates (or reconstructs) $x$ from a latent vector $z$, and vice versa.

And we don’t really know those either (at least so far).

The decoder we can just learn as a supervised learning task, and it is by far the easiest part: we have the input $z$ and the target $x$, and all we need to do is compare the output of the decoder against the input $x$ and we're golden.

But the encoder is a different story, because

$$
\begin{align*}
p(z|x) &= \frac{p(x|z)p(z)}{p(x)} \\
p(x) &= \int p(x|z)p(z)\,dz
\end{align*}
$$

which would mean we would have to integrate over the entire latent space $p(z)$, which is computationally not feasible. So, instead, we will approximate $p(z|x)$ with

$$q(z|x) \approx p(z|x)$$

And if we do our training correctly, then $q(z|x)$ will indeed be a good approximation of the true encoder, and $q(z|x)$ will also be a normal distribution. By the way, this means that it will output a $\mu$ and a $\sigma$, which we can use to sample the vector $z$.

Deriving the Loss Function

Here's the goal: we want to maximise $\log p(x)$ for each $x$ in our dataset, and if you're asking why, then you're in good company, because that's not immediately obvious. Remember how I said we have two distributions and that one of them is $p(x)$, the data distribution? What $\log p(x)$ tells us is how likely it is to sample $x$ from that distribution. All of our data points $x$ ARE samples from the data distribution, so a good model should assign them as much probability as possible; this is just maximum likelihood. So that is our starting point:

$$\text{maximise} \qquad \log p(x)$$

The first thing we can say is this:

$$\log p(x) = \log \int p(x|z)p(z)\, dz$$

And that means marginalising out $z$. To better understand this, imagine you had two dice (an $x$-die and a $z$-die) whose probabilities are skewed such that higher numbers have higher probabilities. A probability matrix would look like this:

Probability Matrix

The redder areas indicate a higher probability. If you were interested in the probability $p(x=5)$, then, to calculate it, you need to compute the sum over all the outcomes $p(x=5, z=\text{any})$ (that's the part I highlighted in the image), so, in other words, it's:

$$
\begin{align*}
p(x=5) &= \sum_{i=1}^6 p(x=5, z=i) \\
&= \sum_{i=1}^6 p(x=5|z=i)\, p(z=i)
\end{align*}
$$

This process is called marginalisation. It's essentially the same as $\log p(x) = \log \int p(x|z) p(z)\, dz$, except that there we are dealing with continuous values (and the log is there for numerical stability; it doesn't change where the maximum lies).
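
As a quick sanity check, here is the dice example in code. The joint probability table is made up; it just needs to be skewed towards high numbers and sum to one:

# Sketch: marginalising z out of a made-up joint distribution p(x, z).
import jax.numpy as jnp

# Rows are x, columns are z; higher numbers get more weight.
weights = jnp.arange(1, 7, dtype=jnp.float32)
joint = jnp.outer(weights, weights)
joint = joint / joint.sum()      # normalise so all entries sum to 1

# p(x=5) = sum_i p(x=5, z=i); index 4 corresponds to x=5.
p_x5 = joint[4].sum()
print(p_x5)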

We use the “multiply by one” trick to introduce a new term:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right)
\end{align*}
$$

Pretty neat: now we have introduced our approximation. We can rearrange some terms to get this:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right)
\end{align*}
$$

The next trick is a bit unintuitive, but bear with me. Let's say you have a function:

$$f(z) = z$$

and a probability distribution which generates the $z$ values randomly:

$$Q(z)$$

Think back to the skewed dice from earlier. If $Q(z)$ is the random outcome generator for one of those dice, then it will return higher values of $z$ with greater probability than lower ones. So what is the expected value of the function $f(z)$ in this case? It is defined as:

$$\mathbb{E}_{Q(z)} f(z) = \int Q(z) f(z)\, dz$$

Or, in plain English: the expected value of the function $f(z)$ is, summed over all $z$, the probability of sampling that $z$ times the value $f(z)$. For our dice example, we could say that we have a 10% chance to roll a 1, a 90% chance to roll a 6, and nothing else. In this case, the expected value of $f(z)$ is

$$
\begin{align*}
f(z) &= z \\
f(1) &= 1 \\
f(6) &= 6 \\
Q(1) &= 0.1 \\
Q(6) &= 0.9 \\
\mathbb{E} &= Q(1) \cdot f(1) + Q(6) \cdot f(6) \\
&= 0.1 \cdot 1 + 0.9 \cdot 6 \\
&= 5.5
\end{align*}
$$
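
The same expectation can also be estimated by sampling from $Q$ and averaging $f(z)$, which is exactly what we will do later when we approximate the expectation over $q(z|x)$ with one sample per data point. A small sketch with the two-outcome die from above:

# Sketch: exact expectation vs. a Monte Carlo estimate for the two-outcome die.
import jax
import jax.numpy as jnp

f = lambda z: z                        # the value function f(z) = z

outcomes = jnp.array([1.0, 6.0])
probs = jnp.array([0.1, 0.9])
exact = jnp.sum(probs * f(outcomes))   # 5.5, straight from the definition

# Monte Carlo: sample z from Q and average f(z).
key = jax.random.key(0)
samples = jax.random.choice(key, outcomes, shape=(10_000,), p=probs)
estimate = jnp.mean(f(samples))        # close to 5.5
print(exact, estimate)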

We have the same setting in our derivation:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right)
\end{align*}
$$

Where $q(z|x)$ is the probability to sample $z$ (this is akin to the $Q(z)$ from the definition earlier) and $\frac{p(x,z)}{q(z|x)}$ is the value function (the $f(z)$ in the example). With that, we can rewrite it like so:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right) \\
&= \log \mathbb{E}_{q(z|x)} \left( \frac{p(x,z)}{q(z|x)} \right)
\end{align*}
$$

The next trick we can use is Jensen's inequality, which states:

$$f(\mathbb{E}(y)) \ge \mathbb{E}(f(y))$$

if $f$ is a concave function, and since $\log$ is a concave function, we can say

$$\log(\mathbb{E}(y)) \ge \mathbb{E}(\log(y))$$
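
You can convince yourself of this numerically; the following sketch compares both sides for a bunch of arbitrary positive samples:

# Sketch: numerical check of Jensen's inequality for the (concave) log.
import jax
import jax.numpy as jnp

key = jax.random.key(0)
y = jax.random.uniform(key, (10_000,), minval=0.1, maxval=2.0)  # arbitrary positive values

lhs = jnp.log(jnp.mean(y))   # log(E[y])
rhs = jnp.mean(jnp.log(y))   # E[log(y)]
print(lhs >= rhs)            # True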

For our derivation, we can now write:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right) \\
&= \log \mathbb{E}_{q(z|x)} \left( \frac{p(x,z)}{q(z|x)} \right) \\
&\ge \mathbb{E}_{q(z|x)} \left( \log \frac{p(x,z)}{q(z|x)} \right)
\end{align*}
$$

This is great, because now, if we can somehow increase $\mathbb{E}_{q(z|x)} \left( \log \frac{p(x,z)}{q(z|x)} \right)$, then it will automatically raise the lower bound on $\log p(x)$.

Because $p(x,z) = p(x|z)p(z)$, we can write and rearrange the terms like so:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right) \\
&= \log \mathbb{E}_{q(z|x)} \left( \frac{p(x,z)}{q(z|x)} \right) \\
&\ge \mathbb{E}_{q(z|x)} \left( \log \frac{p(x,z)}{q(z|x)} \right) \\
&= \mathbb{E}_{q(z|x)} \left( \log \frac{p(x|z)p(z)}{q(z|x)} \right)
\end{align*}
$$

The laws of logarithms tell us:

$$
\begin{align*}
\log(xy) &= \log x + \log y \\
\log(x/y) &= \log x - \log y
\end{align*}
$$

And because of that, we can rewrite the term as:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x,z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x,z)}{q(z|x)}\, dz \right) \\
&= \log \mathbb{E}_{q(z|x)} \left( \frac{p(x,z)}{q(z|x)} \right) \\
&\ge \mathbb{E}_{q(z|x)} \left( \log \frac{p(x,z)}{q(z|x)} \right) \\
&= \mathbb{E}_{q(z|x)} \left( \log \frac{p(x|z)p(z)}{q(z|x)} \right) \\
&= \mathbb{E}_{q(z|x)} \left( \log p(x|z) + \log \frac{p(z)}{q(z|x)} \right) \\
&= \mathbb{E}_{q(z|x)} \left( \textcolor{blue}{\log p(x|z)} + \textcolor{green}{\log p(z) - \log q(z|x)} \right) \\
&= \mathbb{E}_{q(z|x)}(\log p(x|z)) + \mathbb{E}_{q(z|x)}(\log p(z) - \log q(z|x))
\end{align*}
$$

The blue part trains the decoder, while the green part trains the encoder. Furthermore, the blue part will simplify to the MSE, while the green part is the exact definition of the negative KL divergence. Let’s start with the decoder part, because that’s a bit easier.

I said earlier that the encoder outputs a $\mu$ and a $\sigma$, which we use to sample the latent vector $z$. The decoder technically also outputs both of these, but in practice we set $\sigma$ to a constant and use $\mu$ directly. Because $p(x|z)$ is a normal distribution, we can write this:

$$
\begin{align*}
p(x|z) &= \frac{1}{(2\pi\sigma^2)^{D/2}} \exp\left(-\frac{\|x - x_{rec}(z)\|^2}{2\sigma^2}\right) \\
\log p(x|z) &= \log\left( \frac{1}{(2\pi\sigma^2)^{D/2}} \right) - \frac{\|x - x_{rec}(z)\|^2}{2\sigma^2}
\end{align*}
$$

When it comes to optimisation, we don't care about anything that is constant. This means that the only thing that remains and is NOT constant is:

$$\|x - x_{rec}(z)\|^2$$

Which is the mean squared error (and $x_{rec}$ is the output of our decoder).
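
To check that this really is just the squared error plus constants, here is a small numerical sketch that compares the Gaussian log-density (via `jax.scipy.stats.norm.logpdf`, with $\sigma = 1$) against the hand-derived expression for some arbitrary vectors:

# Sketch: log p(x|z) with sigma=1 equals -1/2 * ||x - x_rec||^2 minus a constant.
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.key(0)
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, (784,))       # some "image"
x_rec = jax.random.uniform(k2, (784,))   # some "reconstruction"

# Sum of independent Gaussian log densities with mean x_rec and sigma = 1.
log_px_given_z = jnp.sum(norm.logpdf(x, loc=x_rec, scale=1.0))

# The hand-derived form: -1/2 * ||x - x_rec||^2 - d/2 * log(2*pi).
d = x.shape[0]
by_hand = -0.5 * jnp.sum((x - x_rec) ** 2) - 0.5 * d * jnp.log(2 * jnp.pi)

print(jnp.allclose(log_px_given_z, by_hand))  # True

Now, let's have a look at the encoder: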

$$\mathbb{E}_{q(z|x)}(\log p(z) - \log q(z|x))$$

Which is precisely the definition of the negative KL divergence, and because both $q(z|x)$ and $p(z)$ are normal distributions, the KL divergence simplifies to a closed form:

$$
\begin{align*}
D_{KL}(q(z|x) \,||\, p(z)) &= \log\frac{\sigma_z}{\sigma_e} + \frac{\sigma_e^2 + (\mu_e - \mu_z)^2}{2\sigma_z^2} - \frac{1}{2} \\
&= \log\frac{1}{\sigma_e} + \frac{\sigma_e^2 + (\mu_e - 0)^2}{2 \cdot 1^2} - \frac{1}{2} \\
&= -\log\sigma_e + \frac{\sigma_e^2 + \mu_e^2}{2} - \frac{1}{2} \\
&= -\frac{1}{2}\log(\sigma_e^2) + \frac{\sigma_e^2 + \mu_e^2 - 1}{2} \\
&= \frac{1}{2} \left( \mu_e^2 + \sigma_e^2 - \log(\sigma_e^2) - 1 \right) \\
&= -\frac{1}{2} \left( 1 + \log(\sigma_e^2) - \mu_e^2 - \sigma_e^2 \right)
\end{align*}
$$

$\mu_z$ and $\sigma_z$ come from $p(z)$, and because $p(z)$ is a standard Gaussian, those are $0$ and $1$ respectively; $\mu_e$ and $\sigma_e$ come from $q(z|x)$ (i.e. the encoder approximation).
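
The last line is exactly the KL term that will show up in the loss below. As a sketch, here is the closed form next to a Monte Carlo estimate of $\mathbb{E}_{q(z|x)}(\log q(z|x) - \log p(z))$ for a single latent dimension with made-up encoder outputs:

# Sketch: closed-form KL vs. a Monte Carlo estimate (made-up mu_e, sigma_e).
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

mu_e, sigma_e = 0.7, 1.3   # arbitrary encoder outputs for one latent dimension

# Closed form: -1/2 * (1 + log(sigma_e^2) - mu_e^2 - sigma_e^2)
closed_form = -0.5 * (1 + jnp.log(sigma_e**2) - mu_e**2 - sigma_e**2)

# Monte Carlo: KL(q || p) = E_q[log q(z) - log p(z)], with p = N(0, 1).
key = jax.random.key(0)
z = mu_e + sigma_e * jax.random.normal(key, (100_000,))
mc = jnp.mean(norm.logpdf(z, loc=mu_e, scale=sigma_e) - norm.logpdf(z))

print(closed_form, mc)     # the two numbers should be close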

So, if we put everything together, we get:

$$
\begin{align*}
\log p(x) &= \log \left( \int p(x|z)p(z)\, dz \right) \\
&= \log \left( \int \frac{q(z|x)}{q(z|x)} p(x|z)p(z)\, dz \right) \\
&= \log \left( \int q(z|x) \frac{p(x|z)p(z)}{q(z|x)}\, dz \right) \\
&= \log \mathbb{E}_{q(z|x)} \left[ \frac{p(x|z)p(z)}{q(z|x)} \right] \\
&\geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p(x|z)p(z)}{q(z|x)} \right] \quad \text{(Jensen's inequality)} \\
&= \mathbb{E}_{q(z|x)} \left[ \log p(x|z) + \log p(z) - \log q(z|x) \right] \\
&= \mathbb{E}_{q(z|x)}[\log p(x|z)] + \mathbb{E}_{q(z|x)}[\log p(z) - \log q(z|x)] \\
&= \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) \,||\, p(z)) \quad \text{(this is the ELBO)} \\
\log p(x|z) &= -\frac{1}{2}\|x - x_{rec}(z)\|^2 - \frac{d}{2}\log(2\pi) \\
D_{KL}(q(z|x) \,||\, p(z)) &= \frac{1}{2}\sum(\mu_e^2 + \sigma_e^2 - \log \sigma_e^2 - 1) \\
\text{Loss} &= -\text{ELBO} \\
&= -\mathbb{E}_{q(z|x)}[\log p(x|z)] + D_{KL}(q(z|x) \,||\, p(z)) \\
&= -\mathbb{E}_{q(z|x)}\left[-\frac{1}{2}\|x - x_{rec}(z)\|^2 - \frac{d}{2}\log(2\pi)\right] + D_{KL}(q(z|x) \,||\, p(z)) \\
&= \mathbb{E}_{q(z|x)}\left[\frac{1}{2}\|x - x_{rec}(z)\|^2 + \frac{d}{2}\log(2\pi)\right] + D_{KL}(q(z|x) \,||\, p(z)) \\
&= \frac{1}{2}\mathbb{E}_{q(z|x)}[\|x - x_{rec}(z)\|^2] + \frac{1}{2}\sum(\mu_e^2 + \sigma_e^2 - \log \sigma_e^2 - 1) + \text{const} \\
\mathcal{L} &= \mathbb{E}_{q(z|x)}[\|x - x_{rec}(z)\|^2] + \frac{1}{2}\sum(\mu_e^2 + \sigma_e^2 - \log \sigma_e^2 - 1)
\end{align*}
$$

The $+\text{const}$ part refers to the $\frac{d}{2}\log(2\pi)$ term, which is just a constant, and constants don't matter when we minimise with our ML libraries. The same goes for constant scale factors, which is why we dropped the $\frac{1}{2}$ on the reconstruction term too.

The Code

With the loss function in place, the rest is actually pretty simple. Here is the full code:

import clu.metrics as clum
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from tqdm import tqdm


class LossMetrics(eqx.Module, clum.Collection):
    loss: clum.Average.from_output("loss")  # pyright: ignore


tf.random.set_seed(42)
np.random.seed(42)

(train_ds, test_ds), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)  # pyright: ignore


def preprocess(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.reshape(image, (-1,))
    return image, label


BATCH_SIZE = 128

train_ds = train_ds.map(preprocess)
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(ds_info.splits["train"].num_examples)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

test_ds = test_ds.map(preprocess)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)


def reparameterize(mu: Array, logvar: Array, key: PRNGKeyArray) -> Array:
    # z = mu + sigma * eps with eps ~ N(0, 1), so gradients can flow through mu and sigma.
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, shape=mu.shape)
    return mu + std * eps


class Encoder(eqx.Module):
    layers: eqx.nn.Sequential

    mu_layer: eqx.nn.Linear
    logvar_layer: eqx.nn.Linear

    def __init__(
        self, input_dim: int, hidden_dim: int, latent_dim: int, key: PRNGKeyArray
    ):
        key, *subkeys = jax.random.split(key, 10)

        self.layers = eqx.nn.Sequential(
            [
                eqx.nn.Linear(input_dim, hidden_dim, key=subkeys[0]),
                eqx.nn.Lambda(jax.nn.relu),
                eqx.nn.Linear(hidden_dim, hidden_dim, key=subkeys[1]),
                eqx.nn.Lambda(jax.nn.relu),
            ]
        )

        self.mu_layer = eqx.nn.Linear(hidden_dim, latent_dim, key=subkeys[2])
        self.logvar_layer = eqx.nn.Linear(hidden_dim, latent_dim, key=subkeys[3])

    def __call__(self, x: Float[Array, " input_dim"]) -> tuple[Array, Array]:
        x = self.layers(x)
        mu = self.mu_layer(x)
        logvar = self.logvar_layer(x)

        return mu, logvar


class Decoder(eqx.Module):
    layers: eqx.nn.Sequential

    def __init__(
        self, latent_dim: int, hidden_dim: int, output_dim: int, key: PRNGKeyArray
    ):
        key, *subkeys = jax.random.split(key, 10)

        self.layers = eqx.nn.Sequential(
            [
                eqx.nn.Linear(latent_dim, hidden_dim, key=subkeys[0]),
                eqx.nn.Lambda(jax.nn.relu),
                eqx.nn.Linear(hidden_dim, hidden_dim, key=subkeys[1]),
                eqx.nn.Lambda(jax.nn.relu),
                eqx.nn.Linear(hidden_dim, output_dim, key=subkeys[2]),
            ]
        )

    def __call__(self, z: Float[Array, "latent_dim"]) -> Array:
        x = self.layers(z)
        return x


class VAE(eqx.Module):
    encoder: Encoder
    decoder: Decoder

    def __init__(
        self, input_dim: int, hidden_dim: int, latent_dim: int, key: PRNGKeyArray
    ):
        key, *subkeys = jax.random.split(key, 5)
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim, key=subkeys[0])
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim, key=subkeys[1])

    def __call__(self, x: Array, key: PRNGKeyArray):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar, key)

        reconstructed_x = self.decoder(z)
        return reconstructed_x, mu, logvar


def vae_loss(
    model: VAE, x: Float[Array, "batch_size input_dim"], key: PRNGKeyArray
) -> Array:
    keys = jax.random.split(key, len(x))
    reconstructed_x, mu, logvar = eqx.filter_vmap(model)(x, keys)

    # Reconstruction term: ||x - x_rec||^2, summed over pixels, averaged over the batch.
    recon_loss = jnp.mean(jnp.sum(jnp.square(x - reconstructed_x), axis=-1))

    # KL term: -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2), averaged over the batch.
    kl_loss = -0.5 * jnp.mean(
        jnp.sum(1 + logvar - jnp.square(mu) - jnp.exp(logvar), axis=-1)
    )

    total_loss = recon_loss + kl_loss
    return total_loss


vae_model = VAE(input_dim=784, hidden_dim=128, latent_dim=8, key=jax.random.key(44))

images, labels = next(iter(train_ds.take(1)))
images = jnp.array(images)
labels = jnp.array(labels)
loss = vae_loss(vae_model, images, key=jax.random.key(42))
print(loss)

learning_rate = 3e-4
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(eqx.filter(vae_model, eqx.is_array))


@eqx.filter_jit
def step(
    model: PyTree,
    x: Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    key: PRNGKeyArray,
):
    loss, grads = eqx.filter_value_and_grad(vae_loss)(model, x, key)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)

    return loss, model, opt_state


n_epochs = 500

metrics = LossMetrics.empty()
key = jax.random.key(0)
for epoch in tqdm(range(n_epochs)):
    for images, _ in train_ds:
        images = jnp.array(images)
        key, subkey = jax.random.split(key)
        loss, vae_model, opt_state = step(
            vae_model, images, optimizer, opt_state, subkey
        )
        metrics = metrics.merge(LossMetrics.single_from_model_output(loss=loss))

    print(f"Epoch {epoch}, loss {metrics.compute()}")


eqx.tree_serialise_leaves("mnist_vae", vae_model)

vae_model = eqx.tree_deserialise_leaves("mnist_vae", vae_model)


def visualize_reconstructions(model, images, n=10):
    # Get the first n images
    original = images[:n]

    # Reconstruct images
    key = jax.random.key(0)
    keys = jax.random.split(key, n)
    reconstructed, _, _ = eqx.filter_vmap(model)(original, keys)

    # Reshape for visualization
    original = original.reshape(-1, 28, 28)
    reconstructed = reconstructed.reshape(-1, 28, 28)

    # Plot
    fig, axes = plt.subplots(2, n, figsize=(n, 2))
    for i in range(n):
        axes[0, i].imshow(original[i], cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].imshow(reconstructed[i], cmap="gray")
        axes[1, i].axis("off")

    plt.tight_layout()
    plt.savefig("reconstructions.png")
    plt.close()


# Generate random samples from the latent space
def visualize_samples(model, n=10):
    # Sample from the latent space
    key = jax.random.key(42)
    z = jax.random.normal(key, shape=(n, model.encoder.mu_layer.out_features))

    # Decode the samples
    samples = eqx.filter_vmap(model.decoder)(z)

    # Reshape for visualization
    samples = samples.reshape(-1, 28, 28)

    # Plot
    fig, axes = plt.subplots(1, n, figsize=(n, 1))
    for i in range(n):
        axes[i].imshow(samples[i], cmap="gray")
        axes[i].axis("off")

    plt.tight_layout()
    plt.savefig("samples.png")
    plt.close()


# Call these after training
test_images, _ = next(iter(test_ds))
test_images = jnp.array(test_images)
visualize_reconstructions(vae_model, test_images)
visualize_samples(vae_model)

If you run this code, you will get these neat images:

Reconstructions

These are the reconstructions from the VAE. You can compare those to the reconstructions from the normal AE:

Reconstructions AE

And now the real kicker: the newly generated samples:

Samples

One thing that needs more context is this function:

def reparameterize(mu: Array, logvar: Array, key: PRNGKeyArray) -> Array:
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, shape=mu.shape)
    return mu + std * eps

And, more specifically, why we use this function in the VAE class instead of sampling directly using mu and logvar from the encoder, which would look something like this:

np.random.normal(loc=mu, scale=np.exp(0.5 * logvar), size=10)  # naive direct sampling in NumPy

Well, for one, we can't even pass the mean and variance to the jax.random.normal function, and even if we could, what would the gradient be in that case? There is none: backpropagation cannot flow through a random sampling operation, so it would be interrupted right there. That's why we use the reparameterization trick instead: the randomness is moved into $\epsilon \sim \mathcal{N}(0, 1)$, and $z = \mu + \sigma\epsilon$ becomes a deterministic, differentiable function of $\mu$ and $\sigma$.
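
To see that gradients really do flow through the sampling step, here is a tiny sketch that differentiates through the reparameterize function from above with respect to mu and logvar (the toy objective is made up):

# Sketch: jax.grad works through the reparameterization trick.
import jax
import jax.numpy as jnp


def sample_and_score(mu, logvar, key):
    # Some toy objective computed on a sample z drawn via reparameterization.
    z = reparameterize(mu, logvar, key)
    return jnp.sum(z**2)


key = jax.random.key(0)
mu = jnp.zeros(8)
logvar = jnp.zeros(8)

grads = jax.grad(sample_and_score, argnums=(0, 1))(mu, logvar, key)
print(grads[0].shape, grads[1].shape)  # (8,) (8,) -- well-defined gradients

Because the noise comes in as an external input, z is just a deterministic function of mu and logvar, and the gradients above are exactly what the optimizer uses during training.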