diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 3f445c0b..6bd419a9 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -11,3 +11,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
 - Shunta Saito: Added support for PLaMo models.
 - Gabrijel Boduljak: Implemented `CLIP`.
+- Markus Enzweiler: Added the `cvae` examples.
diff --git a/README.md b/README.md
index 653cbf01..15c9cca1 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,9 @@ Some more useful examples are listed below.
 
 ### Image Models
 
+- Image classification using [ResNets on CIFAR-10](cifar).
 - Generating images with [Stable Diffusion](stable_diffusion).
+- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
 
 ### Audio Models
 
diff --git a/cvae/.gitignore b/cvae/.gitignore
new file mode 100644
index 00000000..2bcdfd92
--- /dev/null
+++ b/cvae/.gitignore
@@ -0,0 +1 @@
+models/
diff --git a/cvae/README.md b/cvae/README.md
new file mode 100644
index 00000000..8cfa26ae
--- /dev/null
+++ b/cvae/README.md
@@ -0,0 +1,68 @@
+# Convolutional Variational Autoencoder (CVAE) on MNIST
+
+A convolutional variational autoencoder (CVAE) implemented in MLX and trained
+on MNIST.[^1]
+
+## Setup
+
+Install the requirements:
+
+```
+pip install -r requirements.txt
+```
+
+## Run
+
+To train a CVAE, run:
+
+```shell
+python main.py
+```
+
+To see the supported options, run `python main.py -h`.
+
+Training with the default options should give:
+
+```shell
+$ python main.py
+Options:
+  Device: GPU
+  Seed: 0
+  Batch size: 128
+  Max number of filters: 64
+  Number of epochs: 50
+  Learning rate: 0.001
+  Number of latent dimensions: 8
+Number of trainable params: 0.1493 M
+Epoch    1 | Loss   14626.96 | Throughput  1803.44 im/s | Time     34.3 (s)
+Epoch    2 | Loss   10462.21 | Throughput  1802.20 im/s | Time     34.3 (s)
+...
+Epoch   50 | Loss    8293.13 | Throughput  1804.91 im/s | Time     34.2 (s)
+```
+
+The throughput was measured on a 32GB M1 Max.
+
+Reconstructed and generated images will be saved after each epoch in the
+`models/` path. Below are examples of reconstructed training set images and
+generated images.
+
+#### Reconstruction
+
+![MNIST Reconstructions](assets/rec_mnist.png)
+
+#### Generation
+
+![MNIST Samples](assets/samples_mnist.png)
+
+## Limitations
+
+At the time of writing, MLX does not have transposed 2D convolutions. The
+example approximates them with a combination of nearest neighbor upsampling and
+regular convolutions, similar to the original U-Net. We intend to update this
+example once transposed 2D convolutions are available.
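+
+A minimal sketch of this workaround (the shapes are illustrative; the actual
+implementation is `upsample_nearest` and `UpsamplingConv2d` in `model.py`):
+
+```python
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def upsample_nearest(x, scale: int = 2):
+    # Repeat pixels along H and W: (B, H, W, C) -> (B, scale*H, scale*W, C)
+    B, H, W, C = x.shape
+    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
+    return x.reshape(B, H * scale, W * scale, C)
+
+
+# A stride-1 convolution after 2x nearest neighbor upsampling stands in
+# for a stride-2 transposed convolution.
+conv = nn.Conv2d(64, 32, 3, stride=1, padding=1)
+x = mx.random.normal([1, 8, 8, 64])  # MLX convolutions use BHWC layout
+y = conv(upsample_nearest(x))  # -> (1, 16, 16, 32)
+```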
+
+[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
+  Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
+  Variational Autoencoders](https://arxiv.org/abs/1906.02691).
diff --git a/cvae/assets/rec_mnist.png b/cvae/assets/rec_mnist.png
new file mode 100644
index 00000000..5dd24986
Binary files /dev/null and b/cvae/assets/rec_mnist.png differ
diff --git a/cvae/assets/samples_mnist.png b/cvae/assets/samples_mnist.png
new file mode 100644
index 00000000..b2ab3078
Binary files /dev/null and b/cvae/assets/samples_mnist.png differ
diff --git a/cvae/dataset.py b/cvae/dataset.py
new file mode 100644
index 00000000..3af2ca32
--- /dev/null
+++ b/cvae/dataset.py
@@ -0,0 +1,53 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from mlx.data.datasets import load_mnist
+
+
+def mnist(batch_size, img_size, root=None):
+    # load train and test sets using mlx-data
+    load_fn = load_mnist
+    tr = load_fn(root=root, train=True)
+    test = load_fn(root=root, train=False)
+
+    # number of image channels is 1 for MNIST
+    num_img_channels = 1
+
+    # normalize to [0,1]
+    def normalize(x):
+        return x.astype("float32") / 255.0
+
+    # iterator over training set
+    tr_iter = (
+        tr.shuffle()
+        .to_stream()
+        .image_resize("image", h=img_size[0], w=img_size[1])
+        .key_transform("image", normalize)
+        .batch(batch_size)
+    )
+
+    # iterator over test set
+    test_iter = (
+        test.to_stream()
+        .image_resize("image", h=img_size[0], w=img_size[1])
+        .key_transform("image", normalize)
+        .batch(batch_size)
+    )
+    return tr_iter, test_iter
+
+
+if __name__ == "__main__":
+    batch_size = 32
+    img_size = (64, 64)  # (H, W)
+
+    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)
+
+    B, H, W, C = batch_size, img_size[0], img_size[1], 1
+    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")
+
+    batch_tr_iter = next(tr_iter)
+    assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
+    assert batch_tr_iter["label"].shape == (batch_size,), "Wrong training set size"
+
+    batch_test_iter = next(test_iter)
+    assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong test set size"
+    assert batch_test_iter["label"].shape == (batch_size,), "Wrong test set size"
diff --git a/cvae/main.py b/cvae/main.py
new file mode 100644
index 00000000..7c395a2d
--- /dev/null
+++ b/cvae/main.py
@@ -0,0 +1,229 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import argparse
+import time
+from pathlib import Path
+
+import dataset
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+import model
+import numpy as np
+from mlx.utils import tree_flatten
+from PIL import Image
+
+
+def grid_image_from_batch(image_batch, num_rows):
+    """
+    Generate a grid image from a batch of images.
+    Assumes input has shape (B, H, W, C).
+ """ + + B, H, W, _ = image_batch.shape + + num_cols = B // num_rows + + # Calculate the size of the output grid image + grid_height = num_rows * H + grid_width = num_cols * W + + # Normalize and convert to the desired data type + image_batch = np.array(image_batch * 255).astype(np.uint8) + + # Reshape the batch of images into a 2D grid + grid_image = image_batch.reshape(num_rows, num_cols, H, W, -1) + grid_image = grid_image.swapaxes(1, 2) + grid_image = grid_image.reshape(grid_height, grid_width, -1) + + # Convert the grid to a PIL Image + return Image.fromarray(grid_image.squeeze()) + + +def loss_fn(model, X): + X_recon, mu, logvar = model(X) + + # Reconstruction loss + recon_loss = nn.losses.mse_loss(X_recon, X, reduction="sum") + + # KL divergence between encoder distribution and standard normal: + kl_div = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp()) + + # Total loss + return recon_loss + kl_div + + +def train_epoch(model, data, optimizer, epoch): + loss_acc = 0.0 + throughput_acc = 0.0 + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + # Iterate over training batches + for batch_count, batch in enumerate(data): + X = mx.array(batch["image"]) + + throughput_tic = time.perf_counter() + + # Forward pass + backward pass + update + loss, grads = loss_and_grad_fn(model, X) + optimizer.update(model, grads) + + # Evaluate updated model parameters + mx.eval(model.parameters(), optimizer.state) + + throughput_toc = time.perf_counter() + throughput_acc += X.shape[0] / (throughput_toc - throughput_tic) + loss_acc += loss.item() + + if batch_count > 0 and (batch_count % 10 == 0): + print( + " | ".join( + [ + f"Epoch {epoch:4d}", + f"Loss {(loss_acc / batch_count):10.2f}", + f"Throughput {(throughput_acc / batch_count):8.2f} im/s", + f"Batch {batch_count:5d}", + ] + ), + end="\r", + ) + + return loss_acc, throughput_acc, batch_count + + +def reconstruct(model, batch, out_file): + # Reconstruct a single batch only + images = mx.array(batch["image"]) + images_recon = model(images)[0] + paired_images = mx.stack([images, images_recon]).swapaxes(0, 1).flatten(0, 1) + grid_image = grid_image_from_batch(paired_images, num_rows=16) + grid_image.save(out_file) + + +def generate( + model, + out_file, + num_samples=128, +): + # Sample from the latent distribution: + z = mx.random.normal([num_samples, model.num_latent_dims]) + + # Decode the latent vectors to images: + images = model.decode(z) + + # Save all images in a single file + grid_image = grid_image_from_batch(images, num_rows=8) + grid_image.save(out_file) + + +def main(args): + # Load the data + img_size = (64, 64, 1) + train_iter, test_iter = dataset.mnist( + batch_size=args.batch_size, img_size=img_size[:2] + ) + + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + # Load the model + vae = model.CVAE(args.latent_dims, img_size, args.max_filters) + mx.eval(vae.parameters()) + + num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters())) + print("Number of trainable params: {:0.04f} M".format(num_params / 1e6)) + + optimizer = optim.AdamW(learning_rate=args.lr) + + # Batches for reconstruction + train_batch = next(train_iter) + test_batch = next(test_iter) + + for e in range(1, args.epochs + 1): + # Reset iterators and stats at the beginning of each epoch + train_iter.reset() + vae.train() + + # Train one epoch + tic = time.perf_counter() + loss_acc, throughput_acc, batch_count = train_epoch( + vae, train_iter, optimizer, e + ) + toc = time.perf_counter() + + vae.eval() + + print( + 
" | ".join( + [ + f"Epoch {e:4d}", + f"Loss {(loss_acc / batch_count):10.2f}", + f"Throughput {(throughput_acc / batch_count):8.2f} im/s", + f"Time {toc - tic:8.1f} (s)", + ] + ) + ) + # Reconstruct a batch of training and test images + reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png") + reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png") + + # Generate images + generate(vae, save_dir / f"generated_{e:03d}.png") + + vae.save_weights(str(save_dir / "weights.npz")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--cpu", + action="store_true", + help="Use CPU instead of GPU acceleration", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--batch-size", type=int, default=128, help="Batch size for training" + ) + parser.add_argument( + "--max-filters", + type=int, + default=64, + help="Maximum number of filters in the convolutional layers", + ) + parser.add_argument( + "--epochs", type=int, default=50, help="Number of training epochs" + ) + parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + + parser.add_argument( + "--latent-dims", + type=int, + default=8, + help="Number of latent dimensions (positive integer)", + ) + parser.add_argument( + "--save-dir", + type=str, + default="models/", + help="Path to save the model and reconstructed images.", + ) + + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + + np.random.seed(args.seed) + mx.random.seed(args.seed) + + print("Options: ") + print(f" Device: {'GPU' if not args.cpu else 'CPU'}") + print(f" Seed: {args.seed}") + print(f" Batch size: {args.batch_size}") + print(f" Max number of filters: {args.max_filters}") + print(f" Number of epochs: {args.epochs}") + print(f" Learning rate: {args.lr}") + print(f" Number of latent dimensions: {args.latent_dims}") + + main(args) diff --git a/cvae/model.py b/cvae/model.py new file mode 100644 index 00000000..58352de8 --- /dev/null +++ b/cvae/model.py @@ -0,0 +1,172 @@ +# Copyright © 2023-2024 Apple Inc. + +import math + +import mlx.core as mx +import mlx.nn as nn + + +# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py +def upsample_nearest(x, scale: int = 2): + B, H, W, C = x.shape + x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C)) + x = x.reshape(B, H * scale, W * scale, C) + return x + + +class UpsamplingConv2d(nn.Module): + """ + A convolutional layer that upsamples the input by a factor of 2. MLX does + not yet support transposed convolutions, so we approximate them with + nearest neighbor upsampling followed by a convolution. This is similar to + the approach used in the original U-Net. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding): + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding + ) + + def __call__(self, x): + x = self.conv(upsample_nearest(x)) + return x + + +class Encoder(nn.Module): + """ + A convolutional variational encoder. + Maps the input to a normal distribution in latent space and sample a latent + vector from that distribution. 
+ """ + + def __init__(self, num_latent_dims, image_shape, max_num_filters): + super().__init__() + + # number of filters in the convolutional layers + num_filters_1 = max_num_filters // 4 + num_filters_2 = max_num_filters // 2 + num_filters_3 = max_num_filters + + # Output (BHWC): B x 32 x 32 x num_filters_1 + self.conv1 = nn.Conv2d(image_shape[-1], num_filters_1, 3, stride=2, padding=1) + # Output (BHWC): B x 16 x 16 x num_filters_2 + self.conv2 = nn.Conv2d(num_filters_1, num_filters_2, 3, stride=2, padding=1) + # Output (BHWC): B x 8 x 8 x num_filters_3 + self.conv3 = nn.Conv2d(num_filters_2, num_filters_3, 3, stride=2, padding=1) + + # Batch Normalization + self.bn1 = nn.BatchNorm(num_filters_1) + self.bn2 = nn.BatchNorm(num_filters_2) + self.bn3 = nn.BatchNorm(num_filters_3) + + # Divide the spatial dimensions by 8 because of the 3 strided convolutions + output_shape = [num_filters_3] + [ + dimension // 8 for dimension in image_shape[:-1] + ] + + flattened_dim = math.prod(output_shape) + + # Linear mappings to mean and standard deviation + self.proj_mu = nn.Linear(flattened_dim, num_latent_dims) + self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims) + + def __call__(self, x): + x = nn.leaky_relu(self.bn1(self.conv1(x))) + x = nn.leaky_relu(self.bn2(self.conv2(x))) + x = nn.leaky_relu(self.bn3(self.conv3(x))) + x = mx.flatten(x, 1) # flatten all dimensions except batch + + mu = self.proj_mu(x) + logvar = self.proj_log_var(x) + # Ensure this is the std deviation, not variance + sigma = mx.exp(logvar * 0.5) + + # Generate a tensor of random values from a normal distribution + eps = mx.random.normal(sigma.shape) + + # Reparametrization trick to brackpropagate through sampling. + z = eps * sigma + mu + + return z, mu, logvar + + +class Decoder(nn.Module): + """A convolutional decoder""" + + def __init__(self, num_latent_dims, image_shape, max_num_filters): + super().__init__() + self.num_latent_dims = num_latent_dims + num_img_channels = image_shape[-1] + self.max_num_filters = max_num_filters + + # decoder layers + num_filters_1 = max_num_filters + num_filters_2 = max_num_filters // 2 + num_filters_3 = max_num_filters // 4 + + # divide the last two dimensions by 8 because of the 3 upsampling convolutions + self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [ + num_filters_1 + ] + flattened_dim = math.prod(self.input_shape) + + # Output: flattened_dim + self.lin1 = nn.Linear(num_latent_dims, flattened_dim) + # Output (BHWC): B x 16 x 16 x num_filters_2 + self.upconv1 = UpsamplingConv2d( + num_filters_1, num_filters_2, 3, stride=1, padding=1 + ) + # Output (BHWC): B x 32 x 32 x num_filters_1 + self.upconv2 = UpsamplingConv2d( + num_filters_2, num_filters_3, 3, stride=1, padding=1 + ) + # Output (BHWC): B x 64 x 64 x #img_channels + self.upconv3 = UpsamplingConv2d( + num_filters_3, num_img_channels, 3, stride=1, padding=1 + ) + + # Batch Normalizations + self.bn1 = nn.BatchNorm(num_filters_2) + self.bn2 = nn.BatchNorm(num_filters_3) + + def __call__(self, z): + x = self.lin1(z) + + # reshape to BHWC + x = x.reshape( + -1, self.input_shape[0], self.input_shape[1], self.max_num_filters + ) + + # approximate transposed convolutions with nearest neighbor upsampling + x = nn.leaky_relu(self.bn1(self.upconv1(x))) + x = nn.leaky_relu(self.bn2(self.upconv2(x))) + # sigmoid to ensure pixel values are in [0,1] + x = mx.sigmoid(self.upconv3(x)) + return x + + +class CVAE(nn.Module): + """ + A convolutional variational autoencoder consisting of an encoder and a + decoder. 
+ """ + + def __init__(self, num_latent_dims, input_shape, max_num_filters): + super().__init__() + self.num_latent_dims = num_latent_dims + self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters) + self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters) + + def __call__(self, x): + # image to latent vector + z, mu, logvar = self.encoder(x) + # latent vector to image + x = self.decode(z) + return x, mu, logvar + + def encode(self, x): + return self.encoder(x)[0] + + def decode(self, z): + return self.decoder(z) diff --git a/cvae/requirements.txt b/cvae/requirements.txt new file mode 100644 index 00000000..0fb1d31e --- /dev/null +++ b/cvae/requirements.txt @@ -0,0 +1,4 @@ +mlx>=0.0.9 +mlx-data +numpy +Pillow diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 95dec5c6..629ebe99 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from sys import exit from typing import Dict, Optional, Tuple, Union import mlx.core as mx @@ -6,8 +7,6 @@ import mlx.nn as nn from .base import BaseModelArgs -from sys import exit - try: import hf_olmo except ImportError: