Normalizing Flow

A Real NVP normalizing flow (Dinh et al., 2016) for density estimation and sampling, implemented using MLX.
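For reference, these are the change-of-variables formula and the affine coupling layer from Dinh et al. (2016), written in the paper's notation. Here f maps a data point x to the Gaussian base variable, the first d of D coordinates pass through a layer unchanged, and s and t are the learned scale and translation networks. This summarizes the paper's equations and does not transcribe how flows.py arranges them.

```latex
\log p_X(x) = \log p_Z\big(f(x)\big) + \log\left|\det \frac{\partial f(x)}{\partial x^{\top}}\right|

y_{1:d} = x_{1:d}, \qquad
y_{d+1:D} = x_{d+1:D} \odot \exp\big(s(x_{1:d})\big) + t(x_{1:d})

x_{1:d} = y_{1:d}, \qquad
x_{d+1:D} = \big(y_{d+1:D} - t(y_{1:d})\big) \odot \exp\big(-s(y_{1:d})\big)

\log\left|\det \frac{\partial y}{\partial x^{\top}}\right| = \sum_{j} s(x_{1:d})_j
```

Because a coupling layer's Jacobian is triangular, its log-determinant reduces to the sum of the scale outputs, and inverting the layer never requires inverting s or t. Stacking layers that alternate which coordinates pass through lets every coordinate be transformed.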

The example is written in a more object-oriented style than strictly necessary, with an eye toward extending it to other use cases that could benefit from reusable distribution and bijector classes (distributions.py and bijectors.py).
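As a rough illustration of the distribution/bijector split, here is a dependency-free Python sketch for 2-D points. The class names (Bijector, AffineCoupling, StandardNormal, TransformedDistribution) and their method names are hypothetical and may not match bijectors.py, distributions.py, or flows.py. The scale and shift functions are fixed toy lambdas standing in for learned networks. In this sketch, forward maps base samples toward data space, and log_prob pulls a point back through the inverses while summing their log-determinants.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


class Bijector(ABC):
    """Hypothetical interface: an invertible map that also reports log|det J|."""

    @abstractmethod
    def forward_and_log_det(self, x: Vector) -> Tuple[Vector, float]: ...

    @abstractmethod
    def inverse_and_log_det(self, y: Vector) -> Tuple[Vector, float]: ...


class AffineCoupling(Bijector):
    """2-D affine coupling: one coordinate conditions a scale/shift of the other."""

    def __init__(self, s: Callable[[float], float], t: Callable[[float], float], swap: bool = False):
        self.s, self.t = s, t
        # c: coordinate passed through unchanged, k: coordinate transformed
        self.c, self.k = (1, 0) if swap else (0, 1)

    def forward_and_log_det(self, x: Vector) -> Tuple[Vector, float]:
        y = list(x)
        log_scale = self.s(x[self.c])
        y[self.k] = x[self.k] * math.exp(log_scale) + self.t(x[self.c])
        return y, log_scale  # triangular Jacobian: log|det| equals the log-scale

    def inverse_and_log_det(self, y: Vector) -> Tuple[Vector, float]:
        x = list(y)
        log_scale = self.s(y[self.c])  # y[c] == x[c], so s and t can be recomputed
        x[self.k] = (y[self.k] - self.t(y[self.c])) * math.exp(-log_scale)
        return x, -log_scale


class StandardNormal:
    """Base distribution N(0, I) in `dim` dimensions."""

    def __init__(self, dim: int):
        self.dim = dim

    def log_prob(self, z: Vector) -> float:
        return sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)

    def sample(self, rng: random.Random) -> Vector:
        return [rng.gauss(0.0, 1.0) for _ in range(self.dim)]


class TransformedDistribution:
    """Base distribution pushed forward through a sequence of bijectors."""

    def __init__(self, base: StandardNormal, bijectors: Sequence[Bijector]):
        self.base, self.bijectors = base, list(bijectors)

    def sample(self, rng: random.Random) -> Vector:
        z = self.base.sample(rng)
        for b in self.bijectors:  # base -> data
            z, _ = b.forward_and_log_det(z)
        return z

    def log_prob(self, x: Vector) -> float:
        total_log_det = 0.0
        for b in reversed(self.bijectors):  # data -> base, undoing the last map first
            x, log_det = b.inverse_and_log_det(x)
            total_log_det += log_det
        return self.base.log_prob(x) + total_log_det


flow = TransformedDistribution(
    StandardNormal(2),
    [
        AffineCoupling(s=lambda u: 0.5 * math.tanh(u), t=lambda u: u ** 2),
        AffineCoupling(s=lambda u: 0.1 * u, t=lambda u: -u, swap=True),
    ],
)
x = flow.sample(random.Random(0))
print(x, flow.log_prob(x))
```

Keeping the base distribution separate from the bijectors means a different flow is just a different list of bijectors, while log_prob and sample stay unchanged.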

Basic usage

import mlx.core as mx
from flows import RealNVP

model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)

x = mx.random.normal(shape=(32, 4))

# Evaluate log-density
log_prob = model.log_prob(x=x)

# Draw samples
x_samples = model.sample(sample_shape=(32, 4))

Running the example

Install the dependencies:

pip install -r requirements.txt

The example can be run with:

python main.py [--cpu]

which trains the normalizing flow on the two moons dataset and plots the results to samples.png. The optional --cpu flag runs the example on the CPU; otherwise it uses the GPU by default.
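The dataset and the training objective can be sketched without MLX. make_two_moons below is a hypothetical stand-in that follows the geometry of scikit-learn's make_moons (two interleaving half circles plus Gaussian noise). Whether main.py uses that function, and with what noise level or preprocessing, is an assumption to check against the source. A density-estimating flow is typically fit by maximum likelihood, that is, by minimizing the mean negative log-likelihood of the data. Here standard_normal_log_prob is only a placeholder for a model's log_prob.

```python
import math
import random
from typing import Callable, List, Sequence

Point = List[float]


def make_two_moons(n_samples: int, noise: float = 0.05, seed: int = 0) -> List[Point]:
    """Hypothetical two-moons generator: two interleaving half circles plus noise.

    Mirrors the geometry of scikit-learn's make_moons, without labels or shuffling.
    """
    rng = random.Random(seed)
    n_outer = n_samples // 2
    n_inner = n_samples - n_outer
    points: List[Point] = []
    # Upper moon: top half of a unit circle centred at (0, 0).
    for i in range(n_outer):
        theta = math.pi * i / max(n_outer - 1, 1)
        points.append([math.cos(theta), math.sin(theta)])
    # Lower moon: bottom half of a unit circle centred at (1, 0.5).
    for i in range(n_inner):
        theta = math.pi * i / max(n_inner - 1, 1)
        points.append([1.0 - math.cos(theta), 1.0 - math.sin(theta) - 0.5])
    # Isotropic Gaussian noise on every coordinate.
    return [[x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)] for x, y in points]


def negative_log_likelihood(log_prob: Callable[[Point], float], batch: Sequence[Point]) -> float:
    """Mean negative log-likelihood: the loss a density-estimating flow minimizes."""
    return -sum(log_prob(x) for x in batch) / len(batch)


def standard_normal_log_prob(x: Point) -> float:
    # Placeholder density N(0, I); a trained flow's log_prob would take its place.
    return sum(-0.5 * xi * xi - 0.5 * math.log(2 * math.pi) for xi in x)


data = make_two_moons(256, noise=0.05)
nll = negative_log_likelihood(standard_normal_log_prob, data)
print(f"{len(data)} points, NLL under N(0, I): {nll:.3f}")
```

In a real training loop the loss would be computed on minibatches from the flow's log_prob and minimized with a gradient-based optimizer; the details in main.py may differ.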

For all available options, run:

python main.py --help

Results

[Figure: Samples (samples.png)]