From e7879beb6eefdedce6a82c790461b73cc0227c8e Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Mon, 18 Dec 2023 01:01:47 -0500 Subject: [PATCH] Add requirements and basic usage to normalizing flow example --- flow/README.md | 35 +++++++++++++++++++++++++++++------ flow/requirements.txt | 4 ++++ 2 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 flow/requirements.txt diff --git a/flow/README.md b/flow/README.md index dc07b8e9..aa5850f8 100644 --- a/flow/README.md +++ b/flow/README.md @@ -1,23 +1,46 @@ # Normalizing flow -Real NVP normalizing flow from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803) implemented using `mlx`. +Real NVP normalizing flow for density estimation and sampling from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803), implemented using `mlx`. -The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases benefitting from arbitrary distributions and bijectors. +The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases that could from arbitrary distributions and bijectors. -## Usage +## Basic usage -The example can be run with +```py +import mlx.core as mx +from flows import RealNVP + +model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4) + +x = mx.random.normal(shape=(32, 4)) + +# Evaluate log-density +model.log_prob(x=x) + +# Draw samples +model.sample(sample_shape=(32, 4)) +``` + +## Running the example + +Install the dependencies: + +``` +pip install -r requirements.txt +``` + +The example can be run with: ``` python main.py ``` which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`. -By default the example runs on the GPU. To run on the CPU, do +By default the example runs on the GPU. To run on the CPU, do: ``` python main.py --cpu ``` -For all available options, run +For all available options, run: ``` python main.py --help ``` diff --git a/flow/requirements.txt b/flow/requirements.txt new file mode 100644 index 00000000..dd0991d4 --- /dev/null +++ b/flow/requirements.txt @@ -0,0 +1,4 @@ +mlx +numpy +tqdm +scikit-learn \ No newline at end of file