Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)

* initial commit

* style fixes

* update of ACKNOWLEDGMENTS

* fixed comment

* minor refactoring; removed unused imports

* added cifar and cvae to top-level README.md

* removed mention of cuda/mps in argparse

* fixed training status output

* load_weights() with strict=True

* pretrained model update

* fixed imports and style

* requires mlx>=0.0.9

* updated with results using mlx 0.0.9

* removed mention of private repo

* simplify and combine to one file, more consistency with other examples

* few more nits

* nits

* spell

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Markus Enzweiler
Date: 2024-02-07 05:02:27 +01:00 (committed by GitHub)
commit 9b387007ab
parent 8071aacd98
11 changed files with 531 additions and 2 deletions

ACKNOWLEDGMENTS.md

@@ -11,3 +11,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
+- Markus Enzweiler: Added the `cvae` example.

README.md

@@ -20,7 +20,9 @@ Some more useful examples are listed below.

 ### Image Models

+- Image classification using [ResNets on CIFAR-10](cifar).
 - Generating images with [Stable Diffusion](stable_diffusion).
+- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).

 ### Audio Models

cvae/.gitignore (new file, 1 line)

@@ -0,0 +1 @@
models/

cvae/README.md (new file, 68 lines)

@@ -0,0 +1,68 @@
# Convolutional Variational Autoencoder (CVAE) on MNIST

Convolutional variational autoencoder (CVAE) implementation in MLX using
MNIST.[^1]

## Setup

Install the requirements:

```
pip install -r requirements.txt
```

## Run

To train the CVAE, run:

```shell
python main.py
```

To see the supported options, do `python main.py -h`.

Training with the default options should give:

```shell
$ python main.py
Options:
  Device: GPU
  Seed: 0
  Batch size: 128
  Max number of filters: 64
  Number of epochs: 50
  Learning rate: 0.001
  Number of latent dimensions: 8
Number of trainable params: 0.1493 M
Epoch    1 | Loss   14626.96 | Throughput  1803.44 im/s | Time     34.3 (s)
Epoch    2 | Loss   10462.21 | Throughput  1802.20 im/s | Time     34.3 (s)
...
Epoch   50 | Loss    8293.13 | Throughput  1804.91 im/s | Time     34.2 (s)
```

The throughput was measured on a 32GB M1 Max.

Reconstructed and generated images will be saved after each epoch in the
`models/` path. Below are examples of reconstructed training set images and
generated images.

#### Reconstruction

![MNIST Reconstructions](assets/rec_mnist.png)

#### Generation

![MNIST Samples](assets/samples_mnist.png)

## Limitations

At the time of writing, MLX does not have transposed 2D convolutions. The
example approximates them with a combination of nearest neighbor upsampling and
regular convolutions, similar to the original U-Net. We intend to update this
example once transposed 2D convolutions are available.
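
For illustration, here is a minimal, self-contained sketch of that
approximation, mirroring the `UpsamplingConv2d` layer in `model.py` (the
channel sizes below are arbitrary):

```python
import mlx.core as mx
import mlx.nn as nn


def upsample_nearest(x, scale=2):
    # Repeat each pixel `scale` times along both spatial axes.
    B, H, W, C = x.shape
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    return x.reshape(B, H * scale, W * scale, C)


# 2x nearest neighbor upsampling followed by a stride-1 convolution
# stands in for a stride-2 transposed convolution.
conv = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
y = conv(upsample_nearest(mx.zeros((1, 8, 8, 32))))  # -> (1, 16, 16, 16)
```
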
[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
Variational Autoencoders](https://arxiv.org/abs/1906.02691).

cvae/assets/rec_mnist.png (new binary file, 136 KiB, not shown)

cvae/assets/samples_mnist.png (new binary file, 158 KiB, not shown)
cvae/dataset.py (new file, 53 lines)

@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.

from mlx.data.datasets import load_mnist


def mnist(batch_size, img_size, root=None):
    # load train and test sets using mlx-data
    load_fn = load_mnist
    tr = load_fn(root=root, train=True)
    test = load_fn(root=root, train=False)

    # number of image channels is 1 for MNIST
    num_img_channels = 1

    # normalize to [0,1]
    def normalize(x):
        return x.astype("float32") / 255.0

    # iterator over training set
    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    # iterator over test set
    test_iter = (
        test.to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    return tr_iter, test_iter


if __name__ == "__main__":
    batch_size = 32
    img_size = (64, 64)  # (H, W)

    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)

    B, H, W, C = batch_size, img_size[0], img_size[1], 1
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    batch_tr_iter = next(tr_iter)
    assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
    assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"

    batch_test_iter = next(test_iter)
    assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong test set size"
    assert batch_test_iter["label"].shape == (batch_size,), "Wrong test set size"

cvae/main.py (new file, 229 lines)

@@ -0,0 +1,229 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import time
from pathlib import Path

import dataset
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import model
import numpy as np
from mlx.utils import tree_flatten
from PIL import Image


def grid_image_from_batch(image_batch, num_rows):
    """
    Generate a grid image from a batch of images.
    Assumes input has shape (B, H, W, C).
    """
    B, H, W, _ = image_batch.shape
    num_cols = B // num_rows

    # Calculate the size of the output grid image
    grid_height = num_rows * H
    grid_width = num_cols * W

    # Normalize and convert to the desired data type
    image_batch = np.array(image_batch * 255).astype(np.uint8)

    # Reshape the batch of images into a 2D grid
    grid_image = image_batch.reshape(num_rows, num_cols, H, W, -1)
    grid_image = grid_image.swapaxes(1, 2)
    grid_image = grid_image.reshape(grid_height, grid_width, -1)

    # Convert the grid to a PIL Image
    return Image.fromarray(grid_image.squeeze())


def loss_fn(model, X):
    X_recon, mu, logvar = model(X)

    # Reconstruction loss
    recon_loss = nn.losses.mse_loss(X_recon, X, reduction="sum")

    # KL divergence between encoder distribution and standard normal:
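    # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # evaluated below with logvar = log(sigma^2)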
    kl_div = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp())

    # Total loss
    return recon_loss + kl_div


def train_epoch(model, data, optimizer, epoch):
    loss_acc = 0.0
    throughput_acc = 0.0
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Iterate over training batches
    for batch_count, batch in enumerate(data):
        X = mx.array(batch["image"])

        throughput_tic = time.perf_counter()

        # Forward pass + backward pass + update
        loss, grads = loss_and_grad_fn(model, X)
        optimizer.update(model, grads)

        # Evaluate updated model parameters
        mx.eval(model.parameters(), optimizer.state)

        throughput_toc = time.perf_counter()
        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
        loss_acc += loss.item()

        if batch_count > 0 and (batch_count % 10 == 0):
            print(
                " | ".join(
                    [
                        f"Epoch {epoch:4d}",
                        f"Loss {(loss_acc / batch_count):10.2f}",
                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
                        f"Batch {batch_count:5d}",
                    ]
                ),
                end="\r",
            )

    return loss_acc, throughput_acc, batch_count


def reconstruct(model, batch, out_file):
    # Reconstruct a single batch only
    images = mx.array(batch["image"])
    images_recon = model(images)[0]
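    # Interleave originals and reconstructions: stack to (2, B, H, W, C),
    # swap to (B, 2, H, W, C), then flatten the first two axes to (2B, H, W, C)
    # so each image is followed by its reconstruction in the grid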
    paired_images = mx.stack([images, images_recon]).swapaxes(0, 1).flatten(0, 1)
    grid_image = grid_image_from_batch(paired_images, num_rows=16)
    grid_image.save(out_file)


def generate(
    model,
    out_file,
    num_samples=128,
):
    # Sample latent vectors from the standard normal prior:
    z = mx.random.normal([num_samples, model.num_latent_dims])

    # Decode the latent vectors to images:
    images = model.decode(z)

    # Save all images in a single file
    grid_image = grid_image_from_batch(images, num_rows=8)
    grid_image.save(out_file)


def main(args):
    # Load the data
    img_size = (64, 64, 1)
    train_iter, test_iter = dataset.mnist(
        batch_size=args.batch_size, img_size=img_size[:2]
    )
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
    mx.eval(vae.parameters())

    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
    print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

    optimizer = optim.AdamW(learning_rate=args.lr)

    # Batches for reconstruction
    train_batch = next(train_iter)
    test_batch = next(test_iter)

    for e in range(1, args.epochs + 1):
        # Reset iterators and stats at the beginning of each epoch
        train_iter.reset()
        vae.train()

        # Train one epoch
        tic = time.perf_counter()
        loss_acc, throughput_acc, batch_count = train_epoch(
            vae, train_iter, optimizer, e
        )
        toc = time.perf_counter()

        vae.eval()

        print(
            " | ".join(
                [
                    f"Epoch {e:4d}",
                    f"Loss {(loss_acc / batch_count):10.2f}",
                    f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
                    f"Time {toc - tic:8.1f} (s)",
                ]
            )
        )

        # Reconstruct a batch of training and test images
        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")

        # Generate images
        generate(vae, save_dir / f"generated_{e:03d}.png")

    vae.save_weights(str(save_dir / "weights.npz"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Use CPU instead of GPU acceleration",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--batch-size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--max-filters",
        type=int,
        default=64,
        help="Maximum number of filters in the convolutional layers",
    )
    parser.add_argument(
        "--epochs", type=int, default=50, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument(
        "--latent-dims",
        type=int,
        default=8,
        help="Number of latent dimensions (positive integer)",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="models/",
        help="Path to save the model and reconstructed images.",
    )

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print("Options: ")
    print(f"  Device: {'GPU' if not args.cpu else 'CPU'}")
    print(f"  Seed: {args.seed}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Max number of filters: {args.max_filters}")
    print(f"  Number of epochs: {args.epochs}")
    print(f"  Learning rate: {args.lr}")
    print(f"  Number of latent dimensions: {args.latent_dims}")

    main(args)
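
For reference, a hypothetical invocation that overrides some of the defaults
(all flags as defined by the argparse setup above):

```shell
python main.py --batch-size 64 --epochs 10 --latent-dims 16 --save-dir models/run1/
```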

cvae/model.py (new file, 172 lines)

@@ -0,0 +1,172 @@
# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx
import mlx.nn as nn


# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
def upsample_nearest(x, scale: int = 2):
    B, H, W, C = x.shape
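    # Insert singleton axes after H and W and broadcast, so each pixel is
    # repeated `scale` times along both spatial axes; the reshape below then
    # merges the repeats into H and W (nearest neighbor upsampling)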
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    x = x.reshape(B, H * scale, W * scale, C)
    return x


class UpsamplingConv2d(nn.Module):
    """
    A convolutional layer that upsamples the input by a factor of 2. MLX does
    not yet support transposed convolutions, so we approximate them with
    nearest neighbor upsampling followed by a convolution. This is similar to
    the approach used in the original U-Net.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )

    def __call__(self, x):
        x = self.conv(upsample_nearest(x))
        return x


class Encoder(nn.Module):
    """
    A convolutional variational encoder.
    Maps the input to a normal distribution in latent space and samples a
    latent vector from that distribution.
    """

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        super().__init__()

        # number of filters in the convolutional layers
        num_filters_1 = max_num_filters // 4
        num_filters_2 = max_num_filters // 2
        num_filters_3 = max_num_filters

        # Output (BHWC): B x 32 x 32 x num_filters_1
        self.conv1 = nn.Conv2d(image_shape[-1], num_filters_1, 3, stride=2, padding=1)
        # Output (BHWC): B x 16 x 16 x num_filters_2
        self.conv2 = nn.Conv2d(num_filters_1, num_filters_2, 3, stride=2, padding=1)
        # Output (BHWC): B x 8 x 8 x num_filters_3
        self.conv3 = nn.Conv2d(num_filters_2, num_filters_3, 3, stride=2, padding=1)

        # Batch Normalization
        self.bn1 = nn.BatchNorm(num_filters_1)
        self.bn2 = nn.BatchNorm(num_filters_2)
        self.bn3 = nn.BatchNorm(num_filters_3)

        # Divide the spatial dimensions by 8 because of the 3 strided convolutions
        output_shape = [num_filters_3] + [
            dimension // 8 for dimension in image_shape[:-1]
        ]
        flattened_dim = math.prod(output_shape)

        # Linear mappings to mean and log-variance
        self.proj_mu = nn.Linear(flattened_dim, num_latent_dims)
        self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims)

    def __call__(self, x):
        x = nn.leaky_relu(self.bn1(self.conv1(x)))
        x = nn.leaky_relu(self.bn2(self.conv2(x)))
        x = nn.leaky_relu(self.bn3(self.conv3(x)))
        x = mx.flatten(x, 1)  # flatten all dimensions except batch

        mu = self.proj_mu(x)
        logvar = self.proj_log_var(x)

        # logvar is log(sigma^2), so the standard deviation is exp(0.5 * logvar)
        sigma = mx.exp(logvar * 0.5)

        # Generate a tensor of random values from a standard normal distribution
        eps = mx.random.normal(sigma.shape)

        # Reparameterization trick to backpropagate through sampling
        z = eps * sigma + mu
        return z, mu, logvar


class Decoder(nn.Module):
    """A convolutional decoder"""

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        super().__init__()
        self.num_latent_dims = num_latent_dims
        num_img_channels = image_shape[-1]
        self.max_num_filters = max_num_filters

        # decoder layers
        num_filters_1 = max_num_filters
        num_filters_2 = max_num_filters // 2
        num_filters_3 = max_num_filters // 4

        # divide the spatial dimensions by 8 because of the 3 upsampling convolutions
        self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [
            num_filters_1
        ]
        flattened_dim = math.prod(self.input_shape)

        # Output: flattened_dim
        self.lin1 = nn.Linear(num_latent_dims, flattened_dim)

        # Output (BHWC): B x 16 x 16 x num_filters_2
        self.upconv1 = UpsamplingConv2d(
            num_filters_1, num_filters_2, 3, stride=1, padding=1
        )
        # Output (BHWC): B x 32 x 32 x num_filters_3
        self.upconv2 = UpsamplingConv2d(
            num_filters_2, num_filters_3, 3, stride=1, padding=1
        )
        # Output (BHWC): B x 64 x 64 x num_img_channels
        self.upconv3 = UpsamplingConv2d(
            num_filters_3, num_img_channels, 3, stride=1, padding=1
        )

        # Batch Normalizations
        self.bn1 = nn.BatchNorm(num_filters_2)
        self.bn2 = nn.BatchNorm(num_filters_3)

    def __call__(self, z):
        x = self.lin1(z)

        # reshape to BHWC
        x = x.reshape(
            -1, self.input_shape[0], self.input_shape[1], self.max_num_filters
        )

        # approximate transposed convolutions with nearest neighbor upsampling
        x = nn.leaky_relu(self.bn1(self.upconv1(x)))
        x = nn.leaky_relu(self.bn2(self.upconv2(x)))

        # sigmoid to ensure pixel values are in [0,1]
        x = mx.sigmoid(self.upconv3(x))
        return x


class CVAE(nn.Module):
    """
    A convolutional variational autoencoder consisting of an encoder and a
    decoder.
    """

    def __init__(self, num_latent_dims, input_shape, max_num_filters):
        super().__init__()
        self.num_latent_dims = num_latent_dims
        self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters)
        self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters)

    def __call__(self, x):
        # image to latent vector
        z, mu, logvar = self.encoder(x)
        # latent vector to image
        x = self.decode(z)
        return x, mu, logvar

    def encode(self, x):
        return self.encoder(x)[0]

    def decode(self, z):
        return self.decoder(z)
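
As a quick sanity check, here is a minimal sketch (not part of the commit) that
instantiates the model with the defaults from `main.py` and runs one forward
pass on a fake batch:

```python
import mlx.core as mx
import model

# Defaults from main.py: 8 latent dims, 64x64x1 images, up to 64 filters.
vae = model.CVAE(num_latent_dims=8, input_shape=(64, 64, 1), max_num_filters=64)
mx.eval(vae.parameters())

x = mx.random.uniform(shape=(4, 64, 64, 1))  # a fake batch of 4 images in [0, 1]
x_recon, mu, logvar = vae(x)
print(x_recon.shape, mu.shape, logvar.shape)  # (4, 64, 64, 1) (4, 8) (4, 8)
```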

cvae/requirements.txt (new file, 4 lines)

@@ -0,0 +1,4 @@
mlx>=0.0.9
mlx-data
numpy
Pillow

@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from sys import exit
 from typing import Dict, Optional, Tuple, Union

 import mlx.core as mx
@@ -6,8 +7,6 @@ import mlx.nn as nn

 from .base import BaseModelArgs

-from sys import exit
-
 try:
     import hf_olmo
 except ImportError: