Normalizing Flow
An example of a normalizing flow for density estimation and sampling, implemented in MLX. It implements the Real NVP (real-valued non-volume preserving) model [1].
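For reference, the affine coupling layer and the change-of-variables formula behind Real NVP can be written as follows. The notation follows Dinh et al. (2016): $b$ is a binary mask, $s$ and $t$ are learned scale and translation networks, and $f$ maps data $x$ to the latent $z$. How `flows.py` parameterizes these, and which direction it calls "forward", is not shown here, so treat this as a summary of the paper rather than of the code.

```latex
\begin{aligned}
y &= b \odot x + (1 - b) \odot \big( x \odot \exp(s(b \odot x)) + t(b \odot x) \big) \\
x &= b \odot y + (1 - b) \odot \big( (y - t(b \odot y)) \odot \exp(-s(b \odot y)) \big) \\
\log \left| \det \frac{\partial y}{\partial x} \right| &= \sum_j \big[ (1 - b) \odot s(b \odot x) \big]_j \\
\log p_X(x) &= \log p_Z(f(x)) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|
\end{aligned}
```

Because $b \odot y = b \odot x$, the inverse reuses the same $s$ and $t$ evaluations, and for a composition $f = f_K \circ \dots \circ f_1$ the log-determinants of the individual layers simply add.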
Basic usage
```python
import mlx.core as mx
from flows import RealNVP

model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)

x = mx.random.normal(shape=(32, 4))

# Evaluate log-density
log_prob = model.log_prob(x=x)

# Draw samples
x_samples = model.sample(sample_shape=(32, 4))
```
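`flows.py` itself is not reproduced here. To illustrate what `log_prob` and `sample` do conceptually, the following is a minimal, dependency-free sketch of a two-dimensional Real NVP-style flow in plain Python. It is not the MLX implementation: `toy_conditioner` is a hypothetical fixed function standing in for the learned conditioning network, and the names `AffineCoupling2D` and `ToyRealNVP2D` are illustrative only.

```python
import math
import random


def toy_conditioner(u):
    # Hypothetical stand-in for a learned MLP: returns (log_scale, shift)
    # for the transformed coordinate, computed from the unchanged one.
    return 0.5 * math.tanh(u), 0.3 * u


class AffineCoupling2D:
    """One affine coupling step on 2D points; index `keep` passes through unchanged."""

    def __init__(self, keep, conditioner):
        self.keep = keep
        self.change = 1 - keep
        self.conditioner = conditioner

    def forward(self, z):
        # Latent -> data: scale and shift the changed coordinate.
        x = list(z)
        log_s, t = self.conditioner(z[self.keep])
        x[self.change] = z[self.change] * math.exp(log_s) + t
        return x, log_s  # log|det J| of this step

    def inverse(self, x):
        # Data -> latent: undo the shift, then the scale.
        z = list(x)
        log_s, t = self.conditioner(x[self.keep])
        z[self.change] = (x[self.change] - t) * math.exp(-log_s)
        return z, -log_s  # log|det J| of the inverse step


class ToyRealNVP2D:
    def __init__(self, n_transforms, conditioner=toy_conditioner):
        # Alternate which coordinate is held fixed, analogous to alternating masks.
        self.layers = [AffineCoupling2D(i % 2, conditioner) for i in range(n_transforms)]

    def log_prob(self, x):
        # Flow data back to the standard normal base, accumulating log-determinants.
        total_ldj = 0.0
        for layer in reversed(self.layers):
            x, ldj = layer.inverse(x)
            total_ldj += ldj
        base = sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in x)
        return base + total_ldj

    def sample(self, n, rng=random):
        # Draw from the standard normal base and push each point through the layers.
        out = []
        for _ in range(n):
            z = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
            for layer in self.layers:
                z, _ = layer.forward(z)
            out.append(z)
        return out
```

Because each step only rescales and shifts one coordinate as a function of the other, the inverse is closed-form and the log-determinant is just the log-scale, which keeps both density evaluation and sampling cheap.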
Running the example
Install the dependencies:

```shell
pip install -r requirements.txt
```

The example can be run with:

```shell
python main.py [--cpu]
```
This trains the normalizing flow on the two moons dataset and plots the result in `samples.png`. The optional `--cpu` flag runs the example on the CPU; otherwise it uses the GPU by default.
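To make the training setup concrete, here is a hedged sketch assuming the usual two moons construction (two interleaving noisy half circles, the same geometry as scikit-learn's `make_moons`) and standard maximum-likelihood training. `make_two_moons` and `mean_nll` are hypothetical helpers for illustration; how `main.py` actually loads the data and optimizes the model is not shown here.

```python
import math
import random


def make_two_moons(n, noise=0.05, rng=random):
    """Hypothetical two moons generator: two interleaving half circles plus Gaussian noise."""
    points = []
    for i in range(n):
        theta = math.pi * rng.random()
        if i % 2 == 0:
            # Upper moon: unit half circle centered at (0, 0).
            x, y = math.cos(theta), math.sin(theta)
        else:
            # Lower moon: flipped half circle centered at (1, 0.5).
            x, y = 1.0 - math.cos(theta), 0.5 - math.sin(theta)
        points.append([x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)])
    return points


def mean_nll(log_prob, batch):
    # Maximum-likelihood training minimizes the average negative log-density.
    return -sum(log_prob(p) for p in batch) / len(batch)
```

In a real training loop, `log_prob` would be the model's differentiable log-density (for example `model.log_prob` from the basic usage section), and the optimizer would step on the gradient of this loss with respect to the model parameters.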
For all available options, run:

```shell
python main.py --help
```
Results
[Figure: samples.png]
[1] This example is from "Density estimation using Real NVP", Dinh et al. (2016).