mlx-examples/cvae/README.md
Markus Enzweiler 9b387007ab
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
Co-authored-by: Awni Hannun <awni@apple.com>
2024-02-06 20:02:27 -08:00


# Convolutional Variational Autoencoder (CVAE) on MNIST
An implementation of a convolutional variational autoencoder (CVAE) in MLX,
trained on MNIST.[^1]
## Setup
Install the requirements:
```
pip install -r requirements.txt
```
## Run
To train the CVAE, run:
```shell
python main.py
```
To see the supported options, run `python main.py -h`.
Training with the default options should produce output similar to:
```shell
$ python main.py
Options:
Device: GPU
Seed: 0
Batch size: 128
Max number of filters: 64
Number of epochs: 50
Learning rate: 0.001
Number of latent dimensions: 8
Number of trainable params: 0.1493 M
Epoch 1 | Loss 14626.96 | Throughput 1803.44 im/s | Time 34.3 (s)
Epoch 2 | Loss 10462.21 | Throughput 1802.20 im/s | Time 34.3 (s)
...
Epoch 50 | Loss 8293.13 | Throughput 1804.91 im/s | Time 34.2 (s)
```
The throughput was measured on a 32GB M1 Max.
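
For context on the `Loss` column: a VAE is typically trained by minimizing the negative evidence lower bound (ELBO), which is a reconstruction term plus a KL divergence between the approximate posterior and a standard normal prior. The plain-Python sketch below illustrates that objective for a single example. It assumes a summed squared-error reconstruction term and a diagonal Gaussian posterior; this README does not state which reconstruction loss `main.py` uses, and `vae_loss` is a hypothetical name.

```python
import math


def vae_loss(x, x_recon, mu, logvar):
    """Negative ELBO for one example with a diagonal Gaussian posterior.

    Assumption: the reconstruction term is a summed squared error. The
    actual example may use a different reconstruction loss.
    """
    # Reconstruction term: squared error summed over all pixels.
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions.
    kl = -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))
    return recon + kl


# A perfect reconstruction with a posterior equal to the prior costs nothing.
print(vae_loss([0.0, 1.0], [0.0, 1.0], mu=[0.0] * 8, logvar=[0.0] * 8))  # 0.0
```

If the loss is summed over a batch rather than averaged, its magnitude scales with the batch size, so values are only comparable between runs that use the same batch size and reduction.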
Reconstructed and generated images are saved after each epoch in the
`models/` directory. Below are examples of reconstructed training set images
and generated images.
#### Reconstruction
![MNIST Reconstructions](assets/rec_mnist.png)
#### Generation
![MNIST Samples](assets/samples_mnist.png)
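
Generation in a VAE usually means drawing a latent vector from the standard normal prior and passing it through the trained decoder. The sketch below shows only the sampling step in plain Python, using the reparameterization form `z = mu + sigma * eps`. The helper name `sample_latent` is hypothetical, and the decoder call is omitted because its interface is not described in this README.

```python
import math
import random


def sample_latent(mu, logvar, rng):
    """Reparameterized sample: z = mu + exp(0.5 * logvar) * eps, eps ~ N(0, 1)."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, logvar)
    ]


rng = random.Random(0)  # fixed seed, mirroring "Seed: 0" in the training log
latent_dims = 8         # matches "Number of latent dimensions: 8"

# Sampling from the prior N(0, I) corresponds to mu = 0 and logvar = 0.
z = sample_latent([0.0] * latent_dims, [0.0] * latent_dims, rng)
print(len(z))  # 8
```

During training the same function would be called with the encoder's predicted `mu` and `logvar`, which keeps the sampling step differentiable with respect to both.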
## Limitations
At the time of writing, MLX does not support transposed 2D convolutions. The
example approximates them with nearest-neighbor upsampling followed by regular
convolutions, similar to the original U-Net. We intend to update this example
once transposed 2D convolutions are available.
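
The upsample-then-convolve idea described above can be sketched in plain Python on a single-channel image. This is an illustration only, not the layer used in the example: the function names are hypothetical, and the convolution assumes stride 1, zero padding, and an odd-sized kernel.

```python
def upsample_nearest(img, scale=2):
    """Nearest-neighbor upsampling: repeat each pixel `scale` times per axis."""
    return [
        [value for value in row for _ in range(scale)]
        for row in img
        for _ in range(scale)
    ]


def conv2d_same(img, kernel):
    """Stride-1 2D convolution (cross-correlation) with zero padding.

    Assumes an odd-sized kernel so the output keeps the input's shape.
    """
    h, w = len(img), len(img[0])
    kh, kw = len(kernel), len(kernel[0])
    ph, pw = kh // 2, kw // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(kh):
                for dj in range(kw):
                    y, x = i + di - ph, j + dj - pw
                    # Positions outside the image contribute zero.
                    if 0 <= y < h and 0 <= x < w:
                        acc += img[y][x] * kernel[di][dj]
            out[i][j] = acc
    return out


def upsample_conv(img, kernel, scale=2):
    """Stand-in for a transposed convolution: upsample, then convolve."""
    return conv2d_same(upsample_nearest(img, scale), kernel)


# With an identity kernel the result is just the upsampled 4x4 image.
identity = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
print(upsample_conv([[1, 2], [3, 4]], identity))
```

In a trained decoder the kernel weights are learned, so the convolution can smooth the blocky output of the upsampling step.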
[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
Variational Autoencoders](https://arxiv.org/abs/1906.02691).