# Normalizing flow
Real NVP normalizing flow for density estimation and sampling from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803), implemented using `mlx`.
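
For reference, here are the identities behind Real NVP in the notation of Dinh et al. (2016). A flow is a bijection $f$ from data $x \in \mathbb{R}^D$ to a latent variable with a simple base density $p_Z$. Real NVP builds $f$ from affine coupling layers. Each layer leaves the first $d$ coordinates unchanged and scales and shifts the rest using networks $s$ and $t$. The symbols here are chosen for exposition and are not taken from the example's code.

```latex
\log p_X(x) = \log p_Z\big(f(x)\big) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|

y_{1:d} = x_{1:d}, \qquad y_{d+1:D} = x_{d+1:D} \odot \exp\big(s(x_{1:d})\big) + t(x_{1:d})

x_{1:d} = y_{1:d}, \qquad x_{d+1:D} = \big(y_{d+1:D} - t(y_{1:d})\big) \odot \exp\big(-s(y_{1:d})\big)

\log\left|\det \frac{\partial y}{\partial x}\right| = \sum_{j=1}^{D-d} s(x_{1:d})_j
```

The Jacobian of a coupling layer is triangular, with diagonal entries $1$ and $\exp(s_j)$. Its log-determinant is therefore a cheap sum rather than a general determinant. Inverting the layer also never requires inverting $s$ or $t$. Alternating which coordinates are held fixed lets a stack of such layers transform every dimension.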
The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases that could benefit from arbitrary distributions and bijectors.
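
The distribution/bijector split that this style anticipates can be sketched in a few lines of NumPy. Everything below is hypothetical and for illustration only. The `Bijector` interface, the `AffineCoupling` class, and its toy linear conditioners (standing in for the scale and shift networks of a coupling layer) are not the classes defined in this example, which is written against `mlx` rather than NumPy.

```python
import numpy as np


class Bijector:
    """Hypothetical interface: an invertible map with a tractable log-determinant."""

    def forward(self, x):
        raise NotImplementedError

    def inverse(self, y):
        raise NotImplementedError

    def forward_log_det(self, x):
        """log|det dy/dx| of forward(), evaluated at x."""
        raise NotImplementedError


class AffineCoupling(Bijector):
    """Affine coupling: the first d features set a scale and shift for the remaining D - d."""

    def __init__(self, d, D, rng):
        self.d = d
        # Toy linear conditioners standing in for the scale/shift MLPs (an assumption).
        self.W_s = 0.1 * rng.standard_normal((d, D - d))
        self.W_t = 0.1 * rng.standard_normal((d, D - d))

    def _scale_shift(self, x1):
        # tanh keeps the log-scale bounded, a common stabilizing choice.
        return np.tanh(x1 @ self.W_s), x1 @ self.W_t

    def forward(self, x):
        x1, x2 = x[..., : self.d], x[..., self.d :]
        s, t = self._scale_shift(x1)
        return np.concatenate([x1, x2 * np.exp(s) + t], axis=-1)

    def inverse(self, y):
        # y1 == x1, so the same scale and shift can be recomputed and undone.
        y1, y2 = y[..., : self.d], y[..., self.d :]
        s, t = self._scale_shift(y1)
        return np.concatenate([y1, (y2 - t) * np.exp(-s)], axis=-1)

    def forward_log_det(self, x):
        # Triangular Jacobian: the log-determinant is the sum of the log-scales.
        s, _ = self._scale_shift(x[..., : self.d])
        return s.sum(axis=-1)


rng = np.random.default_rng(0)
layer = AffineCoupling(d=2, D=4, rng=rng)
x = rng.standard_normal((32, 4))
y = layer.forward(x)
print(np.allclose(layer.inverse(y), x))  # True
```

Keeping `forward`, `inverse`, and the log-determinant on one object is what lets a flow chain arbitrary bijectors without knowing their internals.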
## Basic usage
```py
import mlx.core as mx
from flows import RealNVP
model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)
x = mx.random.normal(shape=(32, 4))
# Evaluate log-density
model.log_prob(x=x)
# Draw samples
model.sample(sample_shape=(32, 4))
```
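
In a normalizing flow, evaluating the log-density means pushing data through the bijectors towards the base distribution and adding up their log-determinants. Sampling draws from the base and applies the inverses in reverse order. The NumPy sketch below shows that bookkeeping with a hypothetical `Flow` class and an assumed standard normal base. It uses a deliberately simple, also hypothetical, elementwise affine bijector (`ElementwiseAffine`) in place of Real NVP's coupling layers, so it is not the `RealNVP` implementation.

```python
import numpy as np


class ElementwiseAffine:
    """Hypothetical bijector mapping data to latent: z = (x - shift) * exp(-log_scale)."""

    def __init__(self, shift, log_scale):
        self.shift = np.asarray(shift, dtype=float)
        self.log_scale = np.asarray(log_scale, dtype=float)

    def to_latent(self, x):
        return (x - self.shift) * np.exp(-self.log_scale)

    def to_data(self, z):
        return z * np.exp(self.log_scale) + self.shift

    def to_latent_log_det(self, x):
        # Diagonal Jacobian: log|det dz/dx| = -sum(log_scale), the same for every x.
        return -np.sum(self.log_scale) * np.ones(x.shape[:-1])


class Flow:
    """Hypothetical flow: a stack of bijectors over a standard normal base density."""

    def __init__(self, bijectors):
        self.bijectors = bijectors

    def log_prob(self, x):
        log_det = np.zeros(x.shape[:-1])
        for b in self.bijectors:
            # Each Jacobian term is evaluated at that bijector's own input.
            log_det += b.to_latent_log_det(x)
            x = b.to_latent(x)
        d = x.shape[-1]
        base = -0.5 * np.sum(x**2, axis=-1) - 0.5 * d * np.log(2 * np.pi)
        return base + log_det

    def sample(self, sample_shape, rng):
        z = rng.standard_normal(sample_shape)
        for b in reversed(self.bijectors):
            z = b.to_data(z)
        return z


rng = np.random.default_rng(0)
flow = Flow([
    ElementwiseAffine(shift=[1.0, -2.0], log_scale=[0.5, -0.3]),
    ElementwiseAffine(shift=[0.2, 0.0], log_scale=[0.1, 0.2]),
])
samples = flow.sample(sample_shape=(32, 2), rng=rng)  # shape (32, 2)
log_density = flow.log_prob(samples)                  # shape (32,)
```

For a single `ElementwiseAffine`, `log_prob` reduces to the log-density of a Gaussian with mean `shift` and standard deviation `exp(log_scale)`, which makes the sketch easy to check. Mirroring the `sample_shape=(32, 4)` call above, `sample_shape` here includes the feature dimension.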
## Running the example
Install the dependencies:
```
pip install -r requirements.txt
```
The example can be run with:
```
python main.py
```
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`.
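
The two moons dataset consists of two interleaving half circles in the plane. A minimal NumPy generator is sketched below. The function name `make_two_moons` is hypothetical, and this is not necessarily how `main.py` obtains its data. The geometry matches scikit-learn's `sklearn.datasets.make_moons`, minus its shuffling.

```python
import numpy as np


def make_two_moons(n_samples, noise=0.05, rng=None):
    """Hypothetical generator: two interleaving half circles with Gaussian noise."""
    rng = np.random.default_rng() if rng is None else rng
    n_upper = n_samples // 2
    n_lower = n_samples - n_upper
    theta_upper = np.linspace(0.0, np.pi, n_upper)
    theta_lower = np.linspace(0.0, np.pi, n_lower)
    # Upper moon: unit half circle centred at (0, 0).
    upper = np.stack([np.cos(theta_upper), np.sin(theta_upper)], axis=1)
    # Lower moon: unit half circle centred at (1, 0.5), opening upwards.
    lower = np.stack([1.0 - np.cos(theta_lower), 0.5 - np.sin(theta_lower)], axis=1)
    x = np.concatenate([upper, lower], axis=0)
    return x + noise * rng.standard_normal(x.shape)


data = make_two_moons(1000, noise=0.05, rng=np.random.default_rng(0))  # shape (1000, 2)
```

A curved, bimodal density like this is a useful test case, because no single affine map of a Gaussian can fit it.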
By default the example runs on the GPU. To run on the CPU, do:
```
python main.py --cpu
```
For all available options, run:
```
python main.py --help
```
## Results
