Normalizing Flow

This is an example of a normalizing flow for density estimation and sampling, implemented in MLX. It implements the real NVP (real-valued non-volume preserving) model [1].
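For reference, real NVP rests on the change-of-variables formula and on affine coupling layers whose Jacobian is triangular, so its log-determinant is cheap to compute. The notation below follows Dinh et al. rather than this example's code: $f$ maps base samples $z \sim p_Z$ to data $x$, $m$ is a binary mask (1 means the coordinate is left unchanged), $s$ and $t$ are the scale and shift networks, $\odot$ is the elementwise product, and $y$ is the output of one coupling layer applied to its input $x$.

```latex
\begin{aligned}
\log p_X(x) &= \log p_Z\big(f^{-1}(x)\big) + \log\left|\det \frac{\partial f^{-1}(x)}{\partial x}\right| \\
y &= m \odot x + (1 - m) \odot \big(x \odot \exp(s(m \odot x)) + t(m \odot x)\big) \\
x &= m \odot y + (1 - m) \odot \big((y - t(m \odot y)) \odot \exp(-s(m \odot y))\big) \\
\log\left|\det \frac{\partial y}{\partial x}\right| &= \sum_j (1 - m_j)\, s_j(m \odot x)
\end{aligned}
```

Because $m \odot y = m \odot x$, the inverse can recompute $s$ and $t$ exactly. Stacking layers with alternating masks lets every coordinate be transformed, and the per-layer log-determinants simply add.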

Basic usage

```python
import mlx.core as mx
from flows import RealNVP

model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)

x = mx.random.normal(shape=(32, 4))

# Evaluate log-density
log_prob = model.log_prob(x=x)

# Draw samples
x_samples = model.sample(sample_shape=(32, 4))
```
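To make the two calls above concrete, here is a dependency-free sketch of what `log_prob` and `sample` compute in a real NVP: stacked affine coupling layers with alternating binary masks over a standard-normal base distribution. It is not the code in `flows.py`. The class `TinyRealNVP`, its toy linear conditioner (standing in for the learned network that `d_hidden` and `n_layers` presumably size), and the tanh bound on the log-scale are hypothetical choices made so the example runs with only the Python standard library.

```python
import math
import random


def std_normal_log_prob(z):
    # Sum of independent standard-normal log-densities over the coordinates of z.
    return sum(-0.5 * v * v - 0.5 * math.log(2.0 * math.pi) for v in z)


class TinyRealNVP:
    """Hypothetical, dependency-free stand-in for RealNVP (not the code in flows.py)."""

    def __init__(self, n_transforms, d_params, seed=0):
        rng = random.Random(seed)
        self.d = d_params
        # Alternating binary masks: 1 = passed through unchanged, 0 = transformed.
        self.masks = [
            [1 if j % 2 == k % 2 else 0 for j in range(d_params)]
            for k in range(n_transforms)
        ]
        # Toy linear conditioner weights (2 * d_params rows of length d_params),
        # standing in for the learned MLP of the real model.
        self.weights = [
            [[rng.gauss(0.0, 0.5) for _ in range(d_params)] for _ in range(2 * d_params)]
            for _ in range(n_transforms)
        ]

    def _conditioner(self, k, masked_x):
        # Map the masked input to a log-scale s (bounded by tanh) and a shift t.
        out = [sum(w * v for w, v in zip(row, masked_x)) for row in self.weights[k]]
        return [math.tanh(v) for v in out[: self.d]], out[self.d :]

    def _forward(self, k, x):
        # Coupling layer k: y = m*x + (1 - m)*(x*exp(s) + t), with s, t from m*x.
        m = self.masks[k]
        s, t = self._conditioner(k, [mi * xi for mi, xi in zip(m, x)])
        y = [xi if mi else xi * math.exp(si) + ti for mi, xi, si, ti in zip(m, x, s, t)]
        return y, sum(si for mi, si in zip(m, s) if not mi)

    def _inverse(self, k, y):
        # m*y equals m*x, so the conditioner sees exactly the input used in _forward.
        m = self.masks[k]
        s, t = self._conditioner(k, [mi * yi for mi, yi in zip(m, y)])
        x = [yi if mi else (yi - ti) * math.exp(-si) for mi, yi, si, ti in zip(m, y, s, t)]
        return x, -sum(si for mi, si in zip(m, s) if not mi)

    def log_prob(self, x):
        # Flow x back to the base Gaussian, adding each inverse log-determinant.
        total = 0.0
        for k in reversed(range(len(self.masks))):
            x, log_det = self._inverse(k, x)
            total += log_det
        return std_normal_log_prob(x) + total

    def sample(self, n, seed=None):
        # Draw base samples z ~ N(0, I) and push them through the layers in order.
        rng = random.Random(seed)
        samples = []
        for _ in range(n):
            z = [rng.gauss(0.0, 1.0) for _ in range(self.d)]
            for k in range(len(self.masks)):
                z, _ = self._forward(k, z)
            samples.append(z)
        return samples
```

Here `log_prob` takes a single point (a list of floats) and `sample(n)` returns `n` points, whereas the MLX model works on batched `mx.array`s, as in the `(32, 4)` shapes above.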

Running the example

Install the dependencies:

```shell
pip install -r requirements.txt
```

The example can be run with:

```shell
python main.py [--cpu]
```

This trains the normalizing flow on the two moons dataset and saves a plot of the results to `samples.png`. Pass the optional `--cpu` flag to run the example on the CPU; otherwise, it runs on the GPU by default.
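The two moons dataset consists of two interleaving half circles in 2-D, a common toy target because its density is strongly non-Gaussian. The sketch below generates it with the standard library only, following the same half-circle layout as scikit-learn's `make_moons`. The helper names `make_two_moons` and `standardize`, the default noise level, and the standardization step are illustrative assumptions; `main.py` may prepare its data differently.

```python
import math
import random


def make_two_moons(n_samples=1000, noise=0.05, seed=0):
    """Hypothetical helper: two interleaving half circles in 2-D, plus Gaussian noise."""
    rng = random.Random(seed)
    n_outer = n_samples // 2
    n_inner = n_samples - n_outer
    points = []
    # Upper moon: unit half circle centred at (0, 0).
    for i in range(n_outer):
        theta = math.pi * i / max(n_outer - 1, 1)
        points.append([math.cos(theta), math.sin(theta)])
    # Lower moon: flipped unit half circle centred at (1, 0.5).
    for i in range(n_inner):
        theta = math.pi * i / max(n_inner - 1, 1)
        points.append([1.0 - math.cos(theta), 0.5 - math.sin(theta)])
    # Add isotropic Gaussian noise to every coordinate.
    for p in points:
        p[0] += rng.gauss(0.0, noise)
        p[1] += rng.gauss(0.0, noise)
    rng.shuffle(points)
    return points


def standardize(points):
    """Shift and scale each coordinate to zero mean and unit variance."""
    n = len(points)
    dims = len(points[0])
    means = [sum(p[j] for p in points) / n for j in range(dims)]
    stds = [math.sqrt(sum((p[j] - means[j]) ** 2 for p in points) / n) for j in range(dims)]
    return [[(p[j] - means[j]) / stds[j] for j in range(dims)] for p in points]
```

Each point is 2-D, so a flow fitted to this data needs `d_params=2` rather than the `d_params=4` used in the basic-usage snippet.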

For all available options, run:

```shell
python main.py --help
```
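The `--cpu` switch is an ordinary command-line flag. A minimal sketch of how such a flag can be declared with `argparse` and mapped to a device choice follows; the helper names `parse_args` and `select_device` are hypothetical, and the real `main.py` defines further options that are not reproduced here. In MLX itself, the CPU is selected with `mx.set_default_device(mx.cpu)`, which appears only as a comment so the sketch runs without MLX installed.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring only the --cpu flag described above."""
    parser = argparse.ArgumentParser(description="Train a normalizing flow on two moons.")
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run on the CPU instead of the default GPU.",
    )
    return parser.parse_args(argv)


def select_device(args):
    # With MLX installed, the CPU branch would call: mx.set_default_device(mx.cpu)
    return "cpu" if args.cpu else "gpu"
```

`argparse` builds the `--help` text automatically from each argument's `help` string, which is a common way to provide an option listing like the one above.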

Results

![Samples](samples.png)
[1] This example is from *Density estimation using Real NVP*, Dinh et al. (2016).