Since we’re talking about variational autoencoders, we should probably first talk about variational Bayes (VB). VB is a way to scale Bayesian inference to problems where exact calculation or Monte Carlo sampling is impractical.
Bayesian inference tends to follow something like this. We have a likelihood, $f(x \mid \theta)$, and a prior, $\pi(\theta)$, and inference is based on the posterior, $\pi(\theta \mid x) \propto f(x \mid \theta)\pi(\theta)$, which usually cannot be computed in closed form.
Many solutions use candidate distributions. Say, for example, we sample from a convenient candidate distribution $q(\theta)$ and then reweight or correct those samples so that they target the posterior.
Importance sampling, Gibbs sampling and other Monte Carlo techniques get combined for complicated problems. However, modern problems often present a challenge that Monte Carlo techniques cannot solve. Enter variational Bayes. Instead of fixing up samples from a candidate distribution $q$, variational Bayes chooses $q$ itself to approximate the posterior. It turns an MC sampling problem into an optimization problem. Wikipedia has a pretty nice introduction to the topic. The most common version of variational Bayes uses the KL divergence. I.e. choose $q$ to minimize $KL\{q(\theta) \,\|\, \pi(\theta \mid x)\}$.
Often $q$ is restricted to a convenient parametric or factorized family so that this optimization is tractable.
It is called variational Bayes, since it uses the calculus of variations to solve this optimization problem. Go over the Wikipedia article’s Gaussian / Gamma prior example. There you can see the derivation of the variational posterior approximation in a case where it can be done analytically. In most cases, one is left with doing it via algorithmic optimization.
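It helps to keep one standard identity in mind (stated here without derivation; it is not part of this section’s argument, but it explains why minimizing the KL divergence is a sensible target):

$$
\log p(x) \;=\; E_q\!\left[\log \frac{f(x \mid \theta)\,\pi(\theta)}{q(\theta)}\right] \;+\; KL\{q(\theta) \,\|\, \pi(\theta \mid x)\},
$$

where the first term on the right is the evidence lower bound (ELBO). Since $\log p(x)$ does not depend on $q$, minimizing the KL divergence to the posterior is equivalent to maximizing the ELBO, which only involves quantities we can evaluate.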
Variational autoencoders were introduced in (Kingma and Welling 2013). A really good tutorial can be found here and some sample code on MNIST can be found here. An alternate way to think about autoencoders is via variational Bayes arguments. Let $x$ denote an observed record (here, a vectorized image) and $z$ an unobserved latent variable associated with it.
We could view any latent probability distribution as an autoencoder, where the encoder is the conditional distribution of the latent variable given the data, $p(z \mid x)$, and the decoder is the conditional distribution of the data given the latent variable, $p(x \mid z)$.
One issue with this approach is that computing $p(z \mid x)$ is quite hard for problems of sufficient scale. Variational Bayes uses approximations instead of the actual distributions. Let $q(z \mid x)$ denote a tractable approximation to $p(z \mid x)$; as before, we choose $q$ to maximize the ELBO, $E_q\{\log p(x, z) - \log q(z \mid x)\}$, where the expectation is over $z \sim q(z \mid x)$.
We can rewrite the ELBO as (try for HW!):

$$
E_q\{\log p(x \mid z)\} \;-\; KL\{q(z \mid x) \,\|\, p(z)\}.
$$
Consider the following assumptions: the approximate posterior is Gaussian, $q(z \mid x) = N\{\mu(x), \mathrm{diag}(\sigma^2(x))\}$, with mean and log variance produced by the encoder network; the prior is standard normal, $p(z) = N(0, I)$; and $p(x \mid z)$ is Gaussian with mean given by the decoder network and constant variance.
We can write the two terms of the negative ELBO for a single observation, up to constants not depending on the parameters, as

$$
\|x_i - \hat x_i\|^2
\quad \text{and} \quad
-\frac{1}{2} \sum_{j} \left(1 + \log \sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2\right),
$$

where $\hat x_i$ is the decoder output for the sampled latent vector and $\mu_{ij}$, $\sigma_{ij}^2$ are the encoder’s mean and variance for latent component $j$. Both of these expressions are the contributions of one row, $x_i$, of the data; the loss used below sums them over the rows in a batch.
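As a concrete illustration (not in the original notes), here is how those two per-row terms look when computed directly with torch for made-up values; the sizes 784 and 200 just mirror the dimensions used later in this chapter.

```python
# Hypothetical numeric illustration of the two per-row terms; not from the original qmd.
import torch

x_i       = torch.rand(784)    # a made-up "image" row
x_hat_i   = torch.rand(784)    # its made-up reconstruction
mu_i      = torch.randn(200)   # encoder mean for this row
log_var_i = torch.randn(200)   # encoder log variance for this row

recon_term = torch.sum((x_i - x_hat_i) ** 2)
kl_term    = -0.5 * torch.sum(1 + log_var_i - mu_i ** 2 - log_var_i.exp())
print(recon_term.item(), kl_term.item())
```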
We’re omitting our imports and the code getting the data. See the qmd file for the full list. Here we set the number of epochs to 10 and the batch size to 128.
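For readers who want a self-contained starting point, here is a rough sketch of the omitted setup. It follows the standard medmnist and PyTorch APIs, but the transform and the learning rate `lr` are assumptions, not taken from the original qmd.

```python
# Minimal setup sketch (assumed, not from the original qmd):
# imports, hyperparameters, and ChestMNIST data loaders.
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import ChestMNIST
import matplotlib.pyplot as plt

NUM_EPOCHS = 10
batch_size = 128
lr = 1e-3  # assumed learning rate; the original value is not shown

transform = transforms.ToTensor()  # scales pixel values to [0, 1]
train_dataset = ChestMNIST(split="train", transform=transform, download=True)
test_dataset  = ChestMNIST(split="test",  transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)
```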
Using downloaded and verified file: /home/bcaffo/.medmnist/chestmnist.npz
Using downloaded and verified file: /home/bcaffo/.medmnist/chestmnist.npz
Here’s a montage of the training data.
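The code that produced the montage is not shown; a sketch along the following lines, using medmnist’s built-in montage helper, should give something similar (the `length` value is an assumption).

```python
# Sketch (assumed, not from the original qmd): montage of training images.
plt.figure(figsize=[8, 8])
plt.imshow(train_dataset.montage(length=10), cmap="gray")
plt.axis("off")
```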
Here’s the shape of the dataset. Note, each image is a single-channel $28 \times 28$ tensor with values scaled to lie in $[0, 1]$.
temp, _ = next(iter(train_dataset))
temp = temp.numpy()
[temp.shape, temp.max(), temp.min()]
[(1, 28, 28), 0.93333334, 0.015686275]
Let’s use Jackson Kang’s wonderful VAE tutorial from here.
x_dim = 784 ## This is 28**2, the length of the vectorized images
hidden_dim = 400
latent_dim = 200
First we define the encoder. Remember, the encoder is our approximation $q(z \mid x)$; it maps a vectorized image to the mean and log variance of the latent distribution.
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.FC_input  = nn.Linear(input_dim,  hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean   = nn.Linear(hidden_dim, latent_dim)
        self.FC_var    = nn.Linear(hidden_dim, latent_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.training = True

    def forward(self, x):
        h_ = self.LeakyReLU(self.FC_input(x))
        h_ = self.LeakyReLU(self.FC_input2(h_))
        mean    = self.FC_mean(h_)  ## mean of q(z | x)
        log_var = self.FC_var(h_)   ## log variance of q(z | x)
        return mean, log_var
Next we define our decoder. Remember, our decoder plays the role of $p(x \mid z)$; it maps a latent vector back to a vectorized image, with a sigmoid keeping the output in $[0, 1]$.
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden  = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output  = nn.Linear(hidden_dim, output_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = self.LeakyReLU(self.FC_hidden(x))
        h = self.LeakyReLU(self.FC_hidden2(h))
        x_hat = torch.sigmoid(self.FC_output(h))  ## pixel intensities in [0, 1]
        return x_hat
To run a datapoint, we encode it to get the mean and log variance, sample a latent vector using the reparameterization trick, and then decode that sample.
class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def forward(self, x):
        mean, log_var = self.Encoder(x)
        sd = torch.exp(0.5 * log_var)          ## standard deviation from the log variance
        z = mean + sd * torch.randn_like(sd)   ## reparameterization trick
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)
model   = Model(Encoder=encoder, Decoder=decoder)
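As a quick sanity check (not in the original), we can push a dummy batch through the untrained model and confirm the output shapes line up with `x_dim` and `latent_dim`.

```python
# Hypothetical shape check, not part of the original code.
dummy = torch.randn(4, x_dim)  # a fake batch of 4 vectorized images
x_hat, mean, log_var = model(dummy)
print(x_hat.shape, mean.shape, log_var.shape)
# expected: torch.Size([4, 784]) torch.Size([4, 200]) torch.Size([4, 200])
```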
The loss we minimize is the negative ELBO: the sum of the reproduction (reconstruction) loss and the KL divergence between the approximate posterior and the standard normal prior.
def loss_function(x, x_hat, mean, log_var):
    ## reconstruction term: summed squared error between the image and its reconstruction
    reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction='sum')
    ## closed-form KL divergence between N(mean, diag(var)) and N(0, I)
    KLD = - 0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return reproduction_loss + KLD
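If you want to convince yourself that the closed-form `KLD` term is right, a quick check (not in the original) is to compare it against `torch.distributions.kl_divergence` for Gaussians.

```python
# Hypothetical sanity check of the closed-form KL term; not part of the original code.
from torch.distributions import Normal, kl_divergence

mu = torch.randn(3, latent_dim)
lv = torch.randn(3, latent_dim)

closed_form = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp())
via_torch = kl_divergence(Normal(mu, torch.exp(0.5 * lv)),
                          Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()
print(torch.allclose(closed_form, via_torch))  # should print True
```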
optimizer = Adam(model.parameters(), lr=lr)

model.train()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (x, _) in enumerate(train_loader):
        ## Note unlike the mnist data, the final batch isn't always the same size
        batch_size = x.shape[0]
        x = x.view(batch_size, x_dim)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)

        loss.backward()
        optimizer.step()
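The loop above doesn’t report any metrics; a small addition (not in the original) such as the following evaluates the same loss on the held-out test set after training.

```python
# Hypothetical evaluation pass, not part of the original code.
model.eval()
with torch.no_grad():
    total_loss, n_images = 0.0, 0
    for x, _ in test_loader:
        x = x.view(x.shape[0], x_dim)
        x_hat, mean, log_var = model(x)
        total_loss += loss_function(x, x_hat, mean, log_var).item()
        n_images += x.shape[0]
print(f"average test loss per image: {total_loss / n_images:.2f}")
```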
model.eval()
#with torch.no_grad():
x, _ = next(iter(test_loader))
batch_size = x.shape[0]
x = x.view(batch_size, x_dim)
x_hat, _, _ = model(x)

x = x.view(batch_size, 28, 28)
x_hat = x_hat.view(batch_size, 28, 28)
plt.figure(figsize=[10, 5])

## left: a test image; right: its reconstruction from the VAE
idx = 0
plt.subplot(1, 2, 1)
plt.imshow(x[idx].numpy(), cmap='gray', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.subplot(1, 2, 2)
plt.imshow(x_hat[idx].detach().numpy(), cmap='gray', vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
## Generate new images: sample latent vectors from the N(0, I) prior and decode them.
no_im = 5
noise = torch.randn(no_im ** 2, latent_dim)
generated_images = decoder(noise)

im = generated_images.view(no_im ** 2, 28, 28)

for i in range(no_im ** 2):
    plt.subplot(5, 5, i + 1)
    plt.imshow(im[i].detach().numpy(), cmap='gray', vmin=0, vmax=1)
    plt.xticks([])
    plt.yticks([])