Normalizing flow example (#133)

* Implement normalizing flow Real NVP example

* Add requirements and basic usage to normalizing flow example

* Minor changes to README in normalizing flow example

* Remove trailing commas in function arguments for unified formatting in flows example

* Fix minor typos, add some annotations

* format + nits in README

* readme fix

* move, minor changes in main, copyright

* remove debug

* fix

* Simplified class structure in distributions; better code re-use in bijectors

* Remove rogue space

* change name again

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Siddharth Mishra-Sharma on 2024-01-13 19:58:48 -05:00, committed by GitHub
parent cd3cff0858
commit 19b6167d81
7 changed files with 339 additions and 0 deletions

normalizing_flow/README.md

@@ -0,0 +1,52 @@
# Normalizing Flow

An example of a normalizing flow for density estimation and sampling
implemented in MLX. This example implements the real NVP (non-volume
preserving) model.[^1]

## Basic usage

```python
import mlx.core as mx
from flows import RealNVP

model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)

x = mx.random.normal(shape=(32, 4))

# Evaluate log-density
log_prob = model.log_prob(x=x)

# Draw samples
x_samples = model.sample(sample_shape=(32, 4))
```

## Running the example

Install the dependencies:

```
pip install -r requirements.txt
```

The example can be run with:

```
python main.py [--cpu]
```

This trains the normalizing flow on the two moons dataset and plots the result
in `samples.png`. The optional `--cpu` flag runs the example on the CPU;
otherwise the GPU is used by default.

For all available options, run:

```
python main.py --help
```

## Results

![Samples](./samples.png)

[^1]: This example is from [Density estimation using Real NVP](
https://arxiv.org/abs/1605.08803), Dinh et al. (2016)

normalizing_flow/bijectors.py

@@ -0,0 +1,60 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple

import mlx.core as mx
import mlx.nn as nn


class Bijector:
    def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        raise NotImplementedError

    def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
        raise NotImplementedError


class AffineBijector(Bijector):
    def __init__(self, shift_and_log_scale: mx.array):
        self.shift_and_log_scale = shift_and_log_scale

    def forward_and_log_det(self, x: mx.array):
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        y = x * mx.exp(log_scale) + shift
        log_det = log_scale
        return y, log_det

    def inverse_and_log_det(self, y: mx.array):
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        x = (y - shift) * mx.exp(-log_scale)
        log_det = -log_scale
        return x, log_det


class MaskedCoupling(Bijector):
    def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
        """Coupling layer with masking and conditioner."""
        self.mask = mask
        self.conditioner = conditioner
        self.bijector = bijector

    def apply_mask(self, x: mx.array, func: callable):
        """Transforms masked indices of `x` conditioned on unmasked indices using `func`."""
        x_masked = mx.where(self.mask, 0.0, x)
        bijector_params = self.conditioner(x_masked)
        y, log_det = func(bijector_params)
        log_det = mx.where(self.mask, log_det, 0.0)
        y = mx.where(self.mask, y, x)
        return y, mx.sum(log_det, axis=-1)

    def forward_and_log_det(self, x: mx.array):
        """Transforms masked indices of `x` conditioned on unmasked indices using bijector."""
        return self.apply_mask(
            x, lambda params: self.bijector(params).forward_and_log_det(x)
        )

    def inverse_and_log_det(self, y: mx.array):
        """Transforms masked indices of `y` conditioned on unmasked indices using bijector."""
        return self.apply_mask(
            y, lambda params: self.bijector(params).inverse_and_log_det(y)
        )
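Because the conditioner only ever sees the unmasked indices (the masked ones are zeroed out before the conditioner runs), a forward pass followed by an inverse pass recovers the input exactly, and the two log-determinants cancel. A minimal round-trip sketch; the single `nn.Linear` conditioner here is an illustrative stand-in for the MLP used in `flows.py`:

```python
import mlx.core as mx
import mlx.nn as nn

from bijectors import AffineBijector, MaskedCoupling

d = 4
mask = mx.arange(d) % 2 == 0  # transform even indices, condition on odd ones
conditioner = nn.Linear(d, 2 * d)  # outputs concatenated (shift, log_scale)
coupling = MaskedCoupling(mask, conditioner, AffineBijector)

x = mx.random.normal(shape=(8, d))
y, fwd_ldj = coupling.forward_and_log_det(x)
x_rec, inv_ldj = coupling.inverse_and_log_det(y)

# Forward then inverse recovers x; the log-determinants cancel
assert mx.allclose(x_rec, x, atol=1e-5).item()
assert mx.allclose(fwd_ldj, -inv_ldj, atol=1e-5).item()
```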

normalizing_flow/distributions.py

@@ -0,0 +1,31 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple, Optional, Union
import math

import mlx.core as mx


class Normal:
    def __init__(self, mu: mx.array, sigma: mx.array):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def sample(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu

    def log_prob(self, x: mx.array):
        return (
            -0.5 * math.log(2 * math.pi)
            - mx.log(self.sigma)
            - 0.5 * ((x - self.mu) / self.sigma) ** 2
        )

    def sample_and_log_prob(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        x = self.sample(sample_shape, key=key)
        return x, self.log_prob(x)
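Note that `log_prob` returns per-dimension log-densities; callers sum over the trailing axis for the joint log-density of the diagonal Gaussian, which is exactly what `RealNVP.log_prob` does. A short usage sketch:

```python
import mlx.core as mx

from distributions import Normal

dist = Normal(mx.zeros(2), mx.ones(2))
x, log_prob = dist.sample_and_log_prob((5, 2), key=mx.random.key(0))

# Per-dimension log-densities; sum the last axis for the joint density
print(x.shape, log_prob.sum(-1).shape)  # (5, 2) (5,)
```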

normalizing_flow/flows.py

@@ -0,0 +1,76 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from bijectors import MaskedCoupling, AffineBijector
from distributions import Normal


class MLP(nn.Module):
    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.gelu(l(x))
        return self.layers[-1](x)


class RealNVP(nn.Module):
    def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
        super().__init__()

        # Alternating masks
        self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)]
        self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list]
        self.freeze(keys=["mask_list"])

        # Conditioning MLPs, one per coupling transform
        self.conditioner_list = [
            MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)
        ]

        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))

    def log_prob(self, x: mx.array):
        """
        Flow back to the base Gaussian and compute the log-density,
        accumulating the transformation log-determinants along the way.
        """
        log_prob = mx.zeros(x.shape[0])
        for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
            x, ldj = MaskedCoupling(
                mask, conditioner, AffineBijector
            ).inverse_and_log_det(x)
            log_prob += ldj
        return log_prob + self.base_dist.log_prob(x).sum(-1)

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ):
        """
        Sample from the base Gaussian and flow towards the target distribution.
        """
        x = self.base_dist.sample(sample_shape, key=key)
        for mask, conditioner in zip(
            self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]
        ):
            x, _ = MaskedCoupling(
                mask, conditioner, AffineBijector
            ).forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array):
        return self.log_prob(x)
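The change of variables underlying `log_prob` is log p_X(x) = log p_Z(f⁻¹(x)) + log |det J_{f⁻¹}(x)|: each inverse coupling contributes its log-determinant `ldj`, and the base Gaussian supplies the final term. A quick smoke-test sketch of an (untrained) flow, including the early stopping via `n_transforms` that `main.py` uses for its intermediate panels:

```python
import mlx.core as mx

from flows import RealNVP

model = RealNVP(n_transforms=4, d_params=2, d_hidden=64, n_layers=2)
mx.eval(model.parameters())

# Sampling and density evaluation are well defined even before training
x = model.sample((16, 2), key=mx.random.key(0))
print(model.log_prob(x).shape)  # (16,)

# Stop the flow after the first two couplings to inspect an
# intermediate distribution (what main.py plots per panel)
x_partial = model.sample((16, 2), n_transforms=2)
```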

normalizing_flow/main.py

@@ -0,0 +1,115 @@
# Copyright © 2023-2024 Apple Inc.

from tqdm import trange

import numpy as np
from sklearn import datasets, preprocessing

import matplotlib.pyplot as plt

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from flows import RealNVP


def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Get two moons dataset with given noise level."""
    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    scaler = preprocessing.StandardScaler()
    x = scaler.fit_transform(x)
    return x


def main(args):
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())

    def loss_fn(model, x):
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for step in steps:
            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))

            optimizer.update(model, grads)
            mx.eval(model.parameters())

            steps.set_postfix(val=loss)

    # Plot samples from the trained flow
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))

    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
        axs[n_transforms].set_xlim(-2, 2)
        axs[n_transforms].set_ylim(-2, 2)
        axs[n_transforms].set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        axs[n_transforms].set_xticklabels([])
        axs[n_transforms].set_yticklabels([])

    # Plot the original data
    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
    axs[-1].set_xlim(-2, 2)
    axs[-1].set_ylim(-2, 2)
    axs[-1].set_title("Original data")
    axs[-1].set_xticklabels([])
    axs[-1].set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_steps", type=int, default=5_000, help="Number of steps to train"
    )
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--n_transforms", type=int, default=6, help="Number of flow transforms"
    )
    parser.add_argument(
        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
    )
    parser.add_argument(
        "--d_hidden",
        type=int,
        default=128,
        help="Hidden dimensionality of coupling conditioner",
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of layers in coupling conditioner",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="Learning rate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
    )
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
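Beyond the sample histograms that `main.py` saves, the trained flow's density can be rendered directly with `log_prob` on a grid. A hedged sketch, not part of this commit; it assumes `model` is the trained flow inside `main`:

```python
import numpy as np
import mlx.core as mx

# Evaluate the learned density on a 200x200 grid over the plotting window
xx, yy = np.meshgrid(np.linspace(-2, 2, 200), np.linspace(-2, 2, 200))
grid = np.stack([xx, yy], axis=-1).reshape(-1, 2).astype(np.float32)
density = np.array(mx.exp(model.log_prob(mx.array(grid)))).reshape(200, 200)

# e.g. plt.imshow(density, extent=(-2, 2, -2, 2), origin="lower", cmap="Blues")
```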

normalizing_flow/requirements.txt

@@ -0,0 +1,5 @@
mlx
numpy
tqdm
scikit-learn
matplotlib

normalizing_flow/samples.png

Binary file not shown (71 KiB image of samples).