Convolutional Variational Autoencoder (CVAE) on MNIST
Convolutional variational autoencoder (CVAE) implementation in MLX using MNIST.[1]
Setup
Install the requirements:
pip install -r requirements.txt
Run
To train a CVAE, run:
python main.py
To see the supported options, run python main.py -h.
Training with the default options should give:
$ python main.py
Options:
Device: GPU
Seed: 0
Batch size: 128
Max number of filters: 64
Number of epochs: 50
Learning rate: 0.001
Number of latent dimensions: 8
Number of trainable params: 0.1493 M
Epoch 1 | Loss 14626.96 | Throughput 1803.44 im/s | Time 34.3 (s)
Epoch 2 | Loss 10462.21 | Throughput 1802.20 im/s | Time 34.3 (s)
...
Epoch 50 | Loss 8293.13 | Throughput 1804.91 im/s | Time 34.2 (s)
The throughput was measured on a 32GB M1 Max.
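The Loss column reports the VAE training objective: a reconstruction term plus the KL divergence between the encoder's Gaussian posterior and a standard normal prior. The NumPy sketch below shows that objective together with the reparameterization trick used to sample the latent code. It assumes a summed squared-error reconstruction term and a diagonal Gaussian posterior parameterized by `mu` and `logvar`; these names and the exact reconstruction term are illustrative assumptions, so check `main.py` for what the example actually uses.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    # z = mu + sigma * eps with eps ~ N(0, I); sigma = exp(0.5 * logvar).
    # Sampling this way keeps z differentiable with respect to mu and logvar.
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

def kl_divergence(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and batch.
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar))

def reconstruction_loss(x, x_hat):
    # Assumed reconstruction term: squared error summed over pixels and batch.
    return np.sum((x_hat - x) ** 2)

def vae_loss(x, x_hat, mu, logvar):
    # Negative ELBO (up to constants): reconstruction error plus KL regularizer.
    return reconstruction_loss(x, x_hat) + kl_divergence(mu, logvar)

rng = np.random.default_rng(0)
mu = np.zeros((4, 8))      # batch of 4, 8 latent dimensions (the default above)
logvar = np.zeros((4, 8))  # unit variance
z = reparameterize(mu, logvar, rng)
print(z.shape)
```

When `mu` is zero and `logvar` is zero the posterior equals the prior, so the KL term vanishes; moving the means away from zero is what the KL term penalizes.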
Reconstructed and generated images will be saved after each epoch in the models/ path. Below are examples of reconstructed training set images and generated images.
Reconstruction
Generation
Limitations
At the time of writing, MLX does not have transposed 2D convolutions. The example approximates them by combining nearest-neighbor upsampling with regular convolutions, similar to the original U-Net. We intend to update this example once transposed 2D convolutions are available.
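The upsample-then-convolve replacement for a transposed convolution can be sketched in plain NumPy. This is an illustration of the technique, not the example's MLX code: it assumes channels-last `(N, H, W, C)` inputs and a `(C_out, kH, kW, C_in)` weight layout (the layouts MLX convolutions use), a stride-1 convolution with an odd kernel size, and zero "same" padding. The function names are hypothetical.

```python
import numpy as np

def upsample_nearest(x, scale=2):
    # x: (N, H, W, C). Repeat each pixel `scale` times along height and width.
    return x.repeat(scale, axis=1).repeat(scale, axis=2)

def conv2d_same(x, w, b):
    # Naive stride-1 2D cross-correlation with zero "same" padding.
    # x: (N, H, W, C_in); w: (C_out, kH, kW, C_in) with odd kH, kW; b: (C_out,)
    n, h, wd, _ = x.shape
    c_out, kh, kw, _ = w.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))
    out = np.zeros((n, h, wd, c_out))
    for i in range(kh):
        for j in range(kw):
            patch = xp[:, i:i + h, j:j + wd, :]  # shifted window, (N, H, W, C_in)
            out += np.einsum("nhwc,oc->nhwo", patch, w[:, i, j, :])
    return out + b

def upsample_conv(x, w, b, scale=2):
    # Stand-in for a stride-2 transposed convolution: double H and W, then convolve.
    return conv2d_same(upsample_nearest(x, scale), w, b)

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 4, 3))
w = rng.standard_normal((5, 3, 3, 3)) * 0.1
y = upsample_conv(x, w, np.zeros(5))
print(y.shape)  # (1, 8, 8, 5)
```

Unlike a learned transposed convolution, the upsampling step here has no parameters; all learning happens in the following regular convolution, which also tends to avoid the checkerboard artifacts strided transposed convolutions can produce.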
[1] For a good overview of VAEs, see the original paper Auto-Encoding Variational Bayes or An Introduction to Variational Autoencoders.