Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Normalizing flow example (#133)
* Implement normalizing flow Real NVP example
* Add requirements and basic usage to normalizing flow example
* Minor changes to README in normalizing flow example
* Remove trailing commas in function arguments for unified formatting in flows example
* Fix minor typos, add some annotations
* format + nits in README
* readme fix
* mov, minor changes in main, copywright
* remove debug
* fix
* Simplified class structure in distributions; better code re-use in bijectors
* Remove rogue space
* change name again
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent cd3cff0858
commit 19b6167d81
normalizing_flow/README.md (new file, 52 lines)
@@ -0,0 +1,52 @@
# Normalizing Flow

An example of a normalizing flow for density estimation and sampling
implemented in MLX. This example implements the real NVP (non-volume
preserving) model.[^1]

## Basic usage

```python
import mlx.core as mx
from flows import RealNVP

model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)

x = mx.random.normal(shape=(32, 4))

# Evaluate log-density
log_prob = model.log_prob(x=x)

# Draw samples
x_samples = model.sample(sample_shape=(32, 4))
```

## Running the example

Install the dependencies:

```
pip install -r requirements.txt
```

The example can be run with:
```
python main.py [--cpu]
```

This trains the normalizing flow on the two moons dataset and plots the result
in `samples.png`. The optional `--cpu` flag can be used to run the example on
the CPU; otherwise it will use the GPU by default.

For all available options, run:

```
python main.py --help
```

## Results

![Samples from the trained flow](samples.png)

[^1]: This example is from [Density estimation using Real NVP](https://arxiv.org/abs/1605.08803), Dinh et al. (2016)
normalizing_flow/bijectors.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple

import mlx.core as mx
import mlx.nn as nn


class Bijector:
    def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        raise NotImplementedError

    def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
        raise NotImplementedError


class AffineBijector(Bijector):
    def __init__(self, shift_and_log_scale: mx.array):
        self.shift_and_log_scale = shift_and_log_scale

    def forward_and_log_det(self, x: mx.array):
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        y = x * mx.exp(log_scale) + shift
        log_det = log_scale
        return y, log_det

    def inverse_and_log_det(self, y: mx.array):
        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
        x = (y - shift) * mx.exp(-log_scale)
        log_det = -log_scale
        return x, log_det


class MaskedCoupling(Bijector):
    def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
        """Coupling layer with masking and conditioner."""
        self.mask = mask
        self.conditioner = conditioner
        self.bijector = bijector

    def apply_mask(self, x: mx.array, func: callable):
        """Transforms masked indices of `x` conditioned on unmasked indices using `func`."""
        x_masked = mx.where(self.mask, 0.0, x)
        bijector_params = self.conditioner(x_masked)
        y, log_det = func(bijector_params)
        log_det = mx.where(self.mask, log_det, 0.0)
        y = mx.where(self.mask, y, x)
        return y, mx.sum(log_det, axis=-1)

    def forward_and_log_det(self, x: mx.array):
        """Transforms masked indices of `x` conditioned on unmasked indices using bijector."""
        return self.apply_mask(
            x, lambda params: self.bijector(params).forward_and_log_det(x)
        )

    def inverse_and_log_det(self, y: mx.array):
        """Transforms masked indices of `y` conditioned on unmasked indices using bijector."""
        return self.apply_mask(
            y, lambda params: self.bijector(params).inverse_and_log_det(y)
        )
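A quick sanity check for the coupling construction above (not part of this commit): applying `forward_and_log_det` and then `inverse_and_log_det` with the same mask and conditioner should recover the input exactly, and the two log-determinants should cancel. A minimal sketch, assuming `bijectors.py` is importable; the `nn.Linear` conditioner here is a stand-in for the MLP used in `flows.py`:

```python
# Sanity-check sketch (not in the commit): round-trip one coupling layer.
import mlx.core as mx
import mlx.nn as nn

from bijectors import AffineBijector, MaskedCoupling

d = 4
mask = (mx.arange(d) % 2 == 0).astype(mx.bool_)  # transform even indices
conditioner = nn.Linear(d, 2 * d)  # x_masked -> concatenated (shift, log_scale)
layer = MaskedCoupling(mask, conditioner, AffineBijector)

x = mx.random.normal(shape=(8, d))
y, ldj_fwd = layer.forward_and_log_det(x)
x_rec, ldj_inv = layer.inverse_and_log_det(y)

print(mx.max(mx.abs(x - x_rec)))          # ~0: the coupling inverts exactly
print(mx.max(mx.abs(ldj_fwd + ldj_inv)))  # ~0: log-determinants cancel
```

The round trip works because the unmasked entries pass through unchanged, so the inverse pass reconstructs the same conditioner input and hence the same affine parameters.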
normalizing_flow/distributions.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple, Optional, Union
import math

import mlx.core as mx


class Normal:
    def __init__(self, mu: mx.array, sigma: mx.array):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def sample(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu

    def log_prob(self, x: mx.array):
        return (
            -0.5 * math.log(2 * math.pi)
            - mx.log(self.sigma)
            - 0.5 * ((x - self.mu) / self.sigma) ** 2
        )

    def sample_and_log_prob(
        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
    ):
        x = self.sample(sample_shape, key=key)
        return x, self.log_prob(x)
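A short usage sketch for `Normal` (not from the commit). Note that `log_prob` returns per-dimension log-densities, which is why `RealNVP.log_prob` sums over the last axis:

```python
# Usage sketch (not in the commit) for the diagonal Normal above.
import math
import mlx.core as mx

from distributions import Normal

dist = Normal(mx.zeros(2), mx.ones(2))

# Standard normal log-density at zero is -0.5 * log(2*pi) per dimension.
print(dist.log_prob(mx.zeros(2)))    # approx. [-0.9189, -0.9189]
print(-0.5 * math.log(2 * math.pi))  # -0.9189...

# Draw samples together with their per-dimension log-densities.
x, lp = dist.sample_and_log_prob((5, 2), key=mx.random.key(0))
print(x.shape, lp.shape)  # (5, 2) (5, 2)
```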
normalizing_flow/flows.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# Copyright © 2023-2024 Apple Inc.

from typing import Tuple, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from bijectors import MaskedCoupling, AffineBijector
from distributions import Normal


class MLP(nn.Module):
    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.gelu(l(x))
        return self.layers[-1](x)


class RealNVP(nn.Module):
    def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
        super().__init__()

        # Alternating masks
        self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)]
        self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list]

        self.freeze(keys=["mask_list"])

        # Conditioning MLP
        self.conditioner_list = [
            MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)
        ]

        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))

    def log_prob(self, x: mx.array):
        """
        Flow back to the primal Gaussian and compute log-density,
        adding the transformation log-determinant along the way.
        """
        log_prob = mx.zeros(x.shape[0])
        for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
            x, ldj = MaskedCoupling(
                mask, conditioner, AffineBijector
            ).inverse_and_log_det(x)
            log_prob += ldj
        return log_prob + self.base_dist.log_prob(x).sum(-1)

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ):
        """
        Sample from the primal Gaussian and flow towards the target distribution.
        """
        x = self.base_dist.sample(sample_shape, key=key)
        for mask, conditioner in zip(
            self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]
        ):
            x, _ = MaskedCoupling(
                mask, conditioner, AffineBijector
            ).forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array):
        return self.log_prob(x)
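`log_prob` above is the change-of-variables formula: flow `x` back through the inverse couplings and add the base log-density `log p_Z(f⁻¹(x))` to the accumulated log-determinants. A minimal sketch (not from the commit) exercising an untrained flow in both directions:

```python
# Sketch (not in the commit): exercise an untrained RealNVP end to end.
import mlx.core as mx

from flows import RealNVP

model = RealNVP(n_transforms=6, d_params=2, d_hidden=64, n_layers=2)
mx.eval(model.parameters())

# Sampling: draw from the base Gaussian, then apply the forward couplings.
x = model.sample((1_000, 2), key=mx.random.key(0))

# Density: invert the couplings back to the base Gaussian, accumulating
# log-determinants (change of variables):
#   log p_X(x) = log p_Z(f_inv(x)) + sum_k log|det J_k|
lp = model.log_prob(x)
print(x.shape, lp.shape)  # (1000, 2) (1000,)

# Negative mean log-likelihood: the training objective used in main.py.
print(-mx.mean(lp))
```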
normalizing_flow/main.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright © 2023-2024 Apple Inc.

from tqdm import trange
import numpy as np
from sklearn import datasets, preprocessing
import matplotlib.pyplot as plt

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from flows import RealNVP


def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Get two moons dataset with given noise level."""
    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    scaler = preprocessing.StandardScaler()
    x = scaler.fit_transform(x)
    return x


def main(args):
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())

    def loss_fn(model, x):
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for step in steps:
            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))

            optimizer.update(model, grads)
            mx.eval(model.parameters())

            steps.set_postfix(val=loss)

    # Plot samples from trained flow
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)

        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
        axs[n_transforms].set_xlim(-2, 2)
        axs[n_transforms].set_ylim(-2, 2)
        axs[n_transforms].set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        axs[n_transforms].set_xticklabels([])
        axs[n_transforms].set_yticklabels([])

    # Plot original data
    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
    axs[-1].set_xlim(-2, 2)
    axs[-1].set_ylim(-2, 2)
    axs[-1].set_title("Original data")
    axs[-1].set_xticklabels([])
    axs[-1].set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_steps", type=int, default=5_000, help="Number of steps to train"
    )
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--n_transforms", type=int, default=6, help="Number of flow transforms"
    )
    parser.add_argument(
        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
    )
    parser.add_argument(
        "--d_hidden",
        type=int,
        default=128,
        help="Hidden dimensionality of coupling conditioner",
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of layers in coupling conditioner",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="Learning rate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
    )
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
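A hypothetical extension (not part of the commit): after training, held-out negative log-likelihood can be reported by reusing `get_moons_dataset`. The sketch below assumes a `model` trained as in `main()`:

```python
# Hypothetical extension (not in the commit): held-out NLL after training.
import mlx.core as mx

from flows import RealNVP
from main import get_moons_dataset  # safe: main.py guards execution with __main__

model = RealNVP(n_transforms=6, d_params=2, d_hidden=128, n_layers=4)
# ... train `model` as in main() ...

x_test = mx.array(get_moons_dataset(n_samples=10_000, noise=0.06))
nll = -mx.mean(model.log_prob(x_test))  # the training loss, on unseen data
print(f"held-out NLL: {nll.item():.3f}")
```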
normalizing_flow/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
mlx
numpy
tqdm
scikit-learn
matplotlib
normalizing_flow/samples.png (new binary file, 71 KiB, not shown)