mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Example of a Convolutional Variational Autoencoder (CVAE) on MNIST (#264)
* initial commit * style fixes * update of ACKNOWLEDGMENTS * fixed comment * minor refactoring; removed unused imports * added cifar and cvae to top-level README.md * removed mention of cuda/mps in argparse * fixed training status output * load_weights() with strict=True * pretrained model update * fixed imports and style * requires mlx>=0.0.9 * updated with results using mlx 0.0.9 * removed mention of private repo * simplify and combine to one file, more consistency with other exmaples * few more nits * nits * spell * format --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
68
cvae/README.md
Normal file
68
cvae/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Convolutional Variational Autoencoder (CVAE) on MNIST
|
||||
|
||||
Convolutional variational autoencoder (CVAE) implementation in MLX using
|
||||
MNIST.[^1]
|
||||
|
||||
## Setup
|
||||
|
||||
Install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
|
||||
To train a VAE run:
|
||||
|
||||
```shell
|
||||
python main.py
|
||||
```
|
||||
|
||||
To see the supported options, do `python main.py -h`.
|
||||
|
||||
Training with the default options should give:
|
||||
|
||||
```shell
|
||||
$ python train.py
|
||||
Options:
|
||||
Device: GPU
|
||||
Seed: 0
|
||||
Batch size: 128
|
||||
Max number of filters: 64
|
||||
Number of epochs: 50
|
||||
Learning rate: 0.001
|
||||
Number of latent dimensions: 8
|
||||
Number of trainable params: 0.1493 M
|
||||
Epoch 1 | Loss 14626.96 | Throughput 1803.44 im/s | Time 34.3 (s)
|
||||
Epoch 2 | Loss 10462.21 | Throughput 1802.20 im/s | Time 34.3 (s)
|
||||
...
|
||||
Epoch 50 | Loss 8293.13 | Throughput 1804.91 im/s | Time 34.2 (s)
|
||||
```
|
||||
|
||||
The throughput was measured on a 32GB M1 Max.
|
||||
|
||||
Reconstructed and generated images will be saved after each epoch in the
|
||||
`models/` path. Below are examples of reconstructed training set images and
|
||||
generated images.
|
||||
|
||||
#### Reconstruction
|
||||
|
||||

|
||||
|
||||
#### Generation
|
||||
|
||||

|
||||
|
||||
|
||||
## Limitations
|
||||
|
||||
At the time of writing, MLX does not have transposed 2D convolutions. The
|
||||
example approximates them with a combination of nearest neighbor upsampling and
|
||||
regular convolutions, similar to the original U-Net. We intend to update this
|
||||
example once transposed 2D convolutions are available.
|
||||
|
||||
[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
|
||||
Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
|
||||
Variational Autoencoders](https://arxiv.org/abs/1906.02691).
|
||||
Reference in New Issue
Block a user