Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit
* style fixes
* update of ACKNOWLEDGMENTS
* fixed comment
* minor refactoring; removed unused imports
* added cifar and cvae to top-level README.md
* removed mention of cuda/mps in argparse
* fixed training status output
* load_weights() with strict=True
* pretrained model update
* fixed imports and style
* requires mlx>=0.0.9
* updated with results using mlx 0.0.9
* removed mention of private repo
* simplify and combine to one file, more consistency with other examples
* few more nits
* nits
* spell
* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
8071aacd98
commit
9b387007ab
@@ -11,3 +11,4 @@ MLX Examples was developed with contributions from the following individuals:
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
- Shunta Saito: Added support for PLaMo models.
- Gabrijel Boduljak: Implemented `CLIP`.
- Markus Enzweiler: Added the `cvae` examples.
@@ -20,7 +20,9 @@ Some more useful examples are listed below.

### Image Models

- Image classification using [ResNets on CIFAR-10](cifar).
- Generating images with [Stable Diffusion](stable_diffusion).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).

### Audio Models
1 cvae/.gitignore vendored Normal file
@@ -0,0 +1 @@
models/
68 cvae/README.md Normal file
@@ -0,0 +1,68 @@
# Convolutional Variational Autoencoder (CVAE) on MNIST

Convolutional variational autoencoder (CVAE) implementation in MLX using
MNIST.[^1]

## Setup

Install the requirements:

```
pip install -r requirements.txt
```

## Run

To train a VAE, run:

```shell
python main.py
```

To see the supported options, do `python main.py -h`.

Training with the default options should give:

```shell
$ python main.py
Options:
  Device: GPU
  Seed: 0
  Batch size: 128
  Max number of filters: 64
  Number of epochs: 50
  Learning rate: 0.001
  Number of latent dimensions: 8
Number of trainable params: 0.1493 M
Epoch    1 | Loss   14626.96 | Throughput  1803.44 im/s | Time     34.3 (s)
Epoch    2 | Loss   10462.21 | Throughput  1802.20 im/s | Time     34.3 (s)
...
Epoch   50 | Loss    8293.13 | Throughput  1804.91 im/s | Time     34.2 (s)
```

The throughput was measured on a 32GB M1 Max.

Reconstructed and generated images will be saved after each epoch in the
`models/` path. Below are examples of reconstructed training set images and
generated images.

#### Reconstruction

![Reconstructed MNIST images](assets/rec_mnist.png)

#### Generation

![Generated MNIST images](assets/samples_mnist.png)

## Limitations

At the time of writing, MLX does not have transposed 2D convolutions. The
example approximates them with a combination of nearest neighbor upsampling
and regular convolutions, similar to the original U-Net; see the sketch below.
We intend to update this example once transposed 2D convolutions are
available.
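A minimal sketch of the idea, using `upsample_nearest` exactly as defined in
this example's `model.py`; the channel counts and spatial sizes here are
assumptions for illustration only:

```python
import mlx.core as mx
import mlx.nn as nn

def upsample_nearest(x, scale: int = 2):
    # Repeat every pixel scale x scale times along H and W (inputs are BHWC).
    B, H, W, C = x.shape
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    return x.reshape(B, H * scale, W * scale, C)

# A stride-1 convolution after 2x nearest neighbor upsampling stands in for a
# stride-2 transposed convolution: it doubles H and W, then mixes channels.
conv = nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1)
x = mx.random.normal([1, 16, 16, 8])
y = conv(upsample_nearest(x))
print(y.shape)  # (1, 32, 32, 4)
```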
[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
Variational Autoencoders](https://arxiv.org/abs/1906.02691).
BIN cvae/assets/rec_mnist.png Normal file
Binary file not shown. After Width: | Height: | Size: 136 KiB
BIN cvae/assets/samples_mnist.png Normal file
Binary file not shown. After Width: | Height: | Size: 158 KiB
53 cvae/dataset.py Normal file
@@ -0,0 +1,53 @@
# Copyright © 2023-2024 Apple Inc.

from mlx.data.datasets import load_mnist


def mnist(batch_size, img_size, root=None):
    # load train and test sets using mlx-data
    load_fn = load_mnist
    tr = load_fn(root=root, train=True)
    test = load_fn(root=root, train=False)

    # number of image channels is 1 for MNIST
    num_img_channels = 1

    # normalize to [0,1]
    def normalize(x):
        return x.astype("float32") / 255.0

    # iterator over training set
    tr_iter = (
        tr.shuffle()
        .to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    # iterator over test set
    test_iter = (
        test.to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )
    return tr_iter, test_iter


if __name__ == "__main__":
    batch_size = 32
    img_size = (64, 64)  # (H, W)

    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)

    B, H, W, C = batch_size, img_size[0], img_size[1], 1
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    batch_tr_iter = next(tr_iter)
    assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
    assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"

    batch_test_iter = next(test_iter)
    assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong test set size"
    assert batch_test_iter["label"].shape == (batch_size,), "Wrong test set size"
229 cvae/main.py Normal file
@@ -0,0 +1,229 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import time
from pathlib import Path

import dataset
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import model
import numpy as np
from mlx.utils import tree_flatten
from PIL import Image


def grid_image_from_batch(image_batch, num_rows):
    """
    Generate a grid image from a batch of images.
    Assumes input has shape (B, H, W, C).
    """

    B, H, W, _ = image_batch.shape

    num_cols = B // num_rows

    # Calculate the size of the output grid image
    grid_height = num_rows * H
    grid_width = num_cols * W

    # Normalize and convert to the desired data type
    image_batch = np.array(image_batch * 255).astype(np.uint8)

    # Reshape the batch of images into a 2D grid
    grid_image = image_batch.reshape(num_rows, num_cols, H, W, -1)
    grid_image = grid_image.swapaxes(1, 2)
    grid_image = grid_image.reshape(grid_height, grid_width, -1)

    # Convert the grid to a PIL Image
    return Image.fromarray(grid_image.squeeze())
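
# Shape example (values assumed for illustration): B = 128 images of
# H = W = 64 with num_rows = 16 give num_cols = 8, i.e. a single
# (16 * 64) x (8 * 64) = 1024 x 512 grid image.
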
def loss_fn(model, X):
    X_recon, mu, logvar = model(X)

    # Reconstruction loss
    recon_loss = nn.losses.mse_loss(X_recon, X, reduction="sum")

    # KL divergence between encoder distribution and standard normal:
    kl_div = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp())

    # Total loss
    return recon_loss + kl_div
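
# For reference, the kl_div term above is the closed form of the KL divergence
# between the encoder's diagonal Gaussian N(mu, diag(exp(logvar))) and the
# standard normal prior N(0, I):
#   KL = -0.5 * sum_j (1 + logvar_j - mu_j**2 - exp(logvar_j))
# summed here over both the latent dimensions and the batch, matching the
# "sum" reduction used for the reconstruction loss.
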
def train_epoch(model, data, optimizer, epoch):
    loss_acc = 0.0
    throughput_acc = 0.0
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Iterate over training batches
    for batch_count, batch in enumerate(data):
        X = mx.array(batch["image"])

        throughput_tic = time.perf_counter()

        # Forward pass + backward pass + update
        loss, grads = loss_and_grad_fn(model, X)
        optimizer.update(model, grads)

        # Evaluate updated model parameters
        mx.eval(model.parameters(), optimizer.state)

        throughput_toc = time.perf_counter()
        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
        loss_acc += loss.item()

        if batch_count > 0 and (batch_count % 10 == 0):
            print(
                " | ".join(
                    [
                        f"Epoch {epoch:4d}",
                        f"Loss {(loss_acc / batch_count):10.2f}",
                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
                        f"Batch {batch_count:5d}",
                    ]
                ),
                end="\r",
            )

    return loss_acc, throughput_acc, batch_count


def reconstruct(model, batch, out_file):
    # Reconstruct a single batch only
    images = mx.array(batch["image"])
    images_recon = model(images)[0]
    paired_images = mx.stack([images, images_recon]).swapaxes(0, 1).flatten(0, 1)
    grid_image = grid_image_from_batch(paired_images, num_rows=16)
    grid_image.save(out_file)


def generate(
    model,
    out_file,
    num_samples=128,
):
    # Sample from the latent distribution:
    z = mx.random.normal([num_samples, model.num_latent_dims])

    # Decode the latent vectors to images:
    images = model.decode(z)

    # Save all images in a single file
    grid_image = grid_image_from_batch(images, num_rows=8)
    grid_image.save(out_file)


def main(args):
    # Load the data
    img_size = (64, 64, 1)
    train_iter, test_iter = dataset.mnist(
        batch_size=args.batch_size, img_size=img_size[:2]
    )

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
    mx.eval(vae.parameters())

    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
    print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

    optimizer = optim.AdamW(learning_rate=args.lr)

    # Batches for reconstruction
    train_batch = next(train_iter)
    test_batch = next(test_iter)

    for e in range(1, args.epochs + 1):
        # Reset iterators and stats at the beginning of each epoch
        train_iter.reset()
        vae.train()

        # Train one epoch
        tic = time.perf_counter()
        loss_acc, throughput_acc, batch_count = train_epoch(
            vae, train_iter, optimizer, e
        )
        toc = time.perf_counter()

        vae.eval()

        print(
            " | ".join(
                [
                    f"Epoch {e:4d}",
                    f"Loss {(loss_acc / batch_count):10.2f}",
                    f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
                    f"Time {toc - tic:8.1f} (s)",
                ]
            )
        )
        # Reconstruct a batch of training and test images
        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")

        # Generate images
        generate(vae, save_dir / f"generated_{e:03d}.png")

        vae.save_weights(str(save_dir / "weights.npz"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Use CPU instead of GPU acceleration",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--batch-size", type=int, default=128, help="Batch size for training"
    )
    parser.add_argument(
        "--max-filters",
        type=int,
        default=64,
        help="Maximum number of filters in the convolutional layers",
    )
    parser.add_argument(
        "--epochs", type=int, default=50, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")

    parser.add_argument(
        "--latent-dims",
        type=int,
        default=8,
        help="Number of latent dimensions (positive integer)",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="models/",
        help="Path to save the model and reconstructed images.",
    )

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print("Options: ")
    print(f"  Device: {'GPU' if not args.cpu else 'CPU'}")
    print(f"  Seed: {args.seed}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Max number of filters: {args.max_filters}")
    print(f"  Number of epochs: {args.epochs}")
    print(f"  Learning rate: {args.lr}")
    print(f"  Number of latent dimensions: {args.latent_dims}")

    main(args)
172 cvae/model.py Normal file
@@ -0,0 +1,172 @@
# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx
import mlx.nn as nn


# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
def upsample_nearest(x, scale: int = 2):
    B, H, W, C = x.shape
    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
    x = x.reshape(B, H * scale, W * scale, C)
    return x
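
# The two inserted axes of length `scale` hold broadcast copies of each pixel;
# the final reshape folds them into H and W, i.e. nearest neighbor upsampling.
# Shape example (assumed for illustration): (1, 2, 2, 3) -> (1, 4, 4, 3).
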

class UpsamplingConv2d(nn.Module):
    """
    A convolutional layer that upsamples the input by a factor of 2. MLX does
    not yet support transposed convolutions, so we approximate them with
    nearest neighbor upsampling followed by a convolution. This is similar to
    the approach used in the original U-Net.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )

    def __call__(self, x):
        x = self.conv(upsample_nearest(x))
        return x


class Encoder(nn.Module):
    """
    A convolutional variational encoder.
    Maps the input to a normal distribution in latent space and samples a
    latent vector from that distribution.
    """

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        super().__init__()

        # number of filters in the convolutional layers
        num_filters_1 = max_num_filters // 4
        num_filters_2 = max_num_filters // 2
        num_filters_3 = max_num_filters

        # Output (BHWC): B x 32 x 32 x num_filters_1
        self.conv1 = nn.Conv2d(image_shape[-1], num_filters_1, 3, stride=2, padding=1)
        # Output (BHWC): B x 16 x 16 x num_filters_2
        self.conv2 = nn.Conv2d(num_filters_1, num_filters_2, 3, stride=2, padding=1)
        # Output (BHWC): B x 8 x 8 x num_filters_3
        self.conv3 = nn.Conv2d(num_filters_2, num_filters_3, 3, stride=2, padding=1)

        # Batch Normalization
        self.bn1 = nn.BatchNorm(num_filters_1)
        self.bn2 = nn.BatchNorm(num_filters_2)
        self.bn3 = nn.BatchNorm(num_filters_3)

        # Divide the spatial dimensions by 8 because of the 3 strided convolutions
        output_shape = [num_filters_3] + [
            dimension // 8 for dimension in image_shape[:-1]
        ]

        flattened_dim = math.prod(output_shape)

        # Linear mappings to the mean and log-variance of the latent distribution
        self.proj_mu = nn.Linear(flattened_dim, num_latent_dims)
        self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims)

    def __call__(self, x):
        x = nn.leaky_relu(self.bn1(self.conv1(x)))
        x = nn.leaky_relu(self.bn2(self.conv2(x)))
        x = nn.leaky_relu(self.bn3(self.conv3(x)))
        x = mx.flatten(x, 1)  # flatten all dimensions except batch

        mu = self.proj_mu(x)
        logvar = self.proj_log_var(x)
        # Convert the log-variance to the standard deviation
        sigma = mx.exp(logvar * 0.5)

        # Generate a tensor of random values from a standard normal distribution
        eps = mx.random.normal(sigma.shape)

        # Reparametrization trick: sampling z = mu + sigma * eps keeps the
        # computation differentiable w.r.t. mu and logvar, so we can
        # backpropagate through the sampling step.
        z = eps * sigma + mu

        return z, mu, logvar

class Decoder(nn.Module):
    """A convolutional decoder"""

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        super().__init__()
        self.num_latent_dims = num_latent_dims
        num_img_channels = image_shape[-1]
        self.max_num_filters = max_num_filters

        # decoder layers
        num_filters_1 = max_num_filters
        num_filters_2 = max_num_filters // 2
        num_filters_3 = max_num_filters // 4

        # divide the last two dimensions by 8 because of the 3 upsampling convolutions
        self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [
            num_filters_1
        ]
        flattened_dim = math.prod(self.input_shape)

        # Output: flattened_dim
        self.lin1 = nn.Linear(num_latent_dims, flattened_dim)
        # Output (BHWC): B x 16 x 16 x num_filters_2
        self.upconv1 = UpsamplingConv2d(
            num_filters_1, num_filters_2, 3, stride=1, padding=1
        )
        # Output (BHWC): B x 32 x 32 x num_filters_3
        self.upconv2 = UpsamplingConv2d(
            num_filters_2, num_filters_3, 3, stride=1, padding=1
        )
        # Output (BHWC): B x 64 x 64 x num_img_channels
        self.upconv3 = UpsamplingConv2d(
            num_filters_3, num_img_channels, 3, stride=1, padding=1
        )

        # Batch Normalizations
        self.bn1 = nn.BatchNorm(num_filters_2)
        self.bn2 = nn.BatchNorm(num_filters_3)

    def __call__(self, z):
        x = self.lin1(z)

        # reshape to BHWC
        x = x.reshape(
            -1, self.input_shape[0], self.input_shape[1], self.max_num_filters
        )

        # approximate transposed convolutions with nearest neighbor upsampling
        x = nn.leaky_relu(self.bn1(self.upconv1(x)))
        x = nn.leaky_relu(self.bn2(self.upconv2(x)))
        # sigmoid to ensure pixel values are in [0,1]
        x = mx.sigmoid(self.upconv3(x))
        return x


class CVAE(nn.Module):
    """
    A convolutional variational autoencoder consisting of an encoder and a
    decoder.
    """

    def __init__(self, num_latent_dims, input_shape, max_num_filters):
        super().__init__()
        self.num_latent_dims = num_latent_dims
        self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters)
        self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters)

    def __call__(self, x):
        # image to latent vector
        z, mu, logvar = self.encoder(x)
        # latent vector to image
        x = self.decode(z)
        return x, mu, logvar

    def encode(self, x):
        return self.encoder(x)[0]

    def decode(self, z):
        return self.decoder(z)
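A minimal smoke test of `model.py` (a sketch; the batch size is an assumption,
and the shapes follow the layer comments above):

```python
import mlx.core as mx
import model

# 64 x 64 grayscale inputs, as used throughout this example.
vae = model.CVAE(num_latent_dims=8, input_shape=(64, 64, 1), max_num_filters=64)

x = mx.random.normal([4, 64, 64, 1])
x_recon, mu, logvar = vae(x)
print(x_recon.shape, mu.shape, logvar.shape)  # (4, 64, 64, 1) (4, 8) (4, 8)

# Decoding random latent vectors yields images, as in generate() in main.py.
z = mx.random.normal([4, vae.num_latent_dims])
print(vae.decode(z).shape)  # (4, 64, 64, 1)
```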
4 cvae/requirements.txt Normal file
@@ -0,0 +1,4 @@
mlx>=0.0.9
mlx-data
numpy
Pillow
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from sys import exit
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
@@ -6,8 +7,6 @@ import mlx.nn as nn

from .base import BaseModelArgs

from sys import exit

try:
    import hf_olmo
except ImportError: