Minor changes to README in normalizing flow example

This commit is contained in:
Siddharth Mishra-Sharma 2023-12-18 01:57:53 -05:00
parent e7879beb6e
commit 18f9646d56

View File

@ -1,8 +1,8 @@
# Normalizing flow # Normalizing Flow
Real NVP normalizing flow for density estimation and sampling from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803), implemented using `mlx`. Real NVP normalizing flow for density estimation and sampling from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803), implemented using `mlx`.
The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases that could from arbitrary distributions and bijectors. The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases that could potentially benefit from the use of distributions and bijectors.
## Basic usage ## Basic usage
@ -15,10 +15,10 @@ model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)
x = mx.random.normal(shape=(32, 4)) x = mx.random.normal(shape=(32, 4))
# Evaluate log-density # Evaluate log-density
model.log_prob(x=x) log_prob = model.log_prob(x=x)
# Draw samples # Draw samples
model.sample(sample_shape=(32, 4)) x_samples = model.sample(sample_shape=(32, 4))
``` ```
## Running the example ## Running the example
@ -31,14 +31,9 @@ pip install -r requirements.txt
The example can be run with: The example can be run with:
``` ```
python main.py python main.py [--cpu]
```
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`.
By default the example runs on the GPU. To run on the CPU, do:
```
python main.py --cpu
``` ```
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`. The optional `--cpu` flag can be used to run the example on the CPU, otherwise it will use the GPU by default.
For all available options, run: For all available options, run:
``` ```