diff --git a/normalizing_flow/README.md b/normalizing_flow/README.md
new file mode 100644
index 00000000..5d951d6e
--- /dev/null
+++ b/normalizing_flow/README.md
@@ -0,0 +1,52 @@
+# Normalizing Flow
+
+An example of a normalizing flow for density estimation and sampling
+implemented in MLX. This example implements the real NVP (non-volume
+preserving) model.[^1]
+
+## Basic usage
+
+```python
+import mlx.core as mx
+from flows import RealNVP
+
+model = RealNVP(n_transforms=8, d_params=4, d_hidden=256, n_layers=4)
+
+x = mx.random.normal(shape=(32, 4))
+
+# Evaluate log-density
+log_prob = model.log_prob(x=x)
+
+# Draw samples
+x_samples = model.sample(sample_shape=(32, 4))
+```
+
+## Running the example
+
+Install the dependencies:
+
+```
+pip install -r requirements.txt
+```
+
+The example can be run with:
+```
+python main.py [--cpu]
+```
+
+This trains the normalizing flow on the two moons dataset and plots the result
+in `samples.png`. The optional `--cpu` flag runs the example on the CPU;
+otherwise it uses the GPU by default.
+
+For all available options, run:
+
+```
+python main.py --help
+```
+
+## Results
+
+![Samples](./samples.png)
+
+[^1]: This example is from [Density estimation using Real NVP](
+    https://arxiv.org/abs/1605.08803), Dinh et al. (2016).
diff --git a/normalizing_flow/bijectors.py b/normalizing_flow/bijectors.py
new file mode 100644
index 00000000..d3c4f8f0
--- /dev/null
+++ b/normalizing_flow/bijectors.py
@@ -0,0 +1,60 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Callable, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class Bijector:
+    def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        raise NotImplementedError
+
+    def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
+        raise NotImplementedError
+
+
+class AffineBijector(Bijector):
+    def __init__(self, shift_and_log_scale: mx.array):
+        self.shift_and_log_scale = shift_and_log_scale
+
+    def forward_and_log_det(self, x: mx.array):
+        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
+        y = x * mx.exp(log_scale) + shift
+        log_det = log_scale
+        return y, log_det
+
+    def inverse_and_log_det(self, y: mx.array):
+        shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
+        x = (y - shift) * mx.exp(-log_scale)
+        log_det = -log_scale
+        return x, log_det
+
+
+class MaskedCoupling(Bijector):
+    def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
+        """Coupling layer with masking and conditioner."""
+        self.mask = mask
+        self.conditioner = conditioner
+        self.bijector = bijector
+
+    def apply_mask(self, x: mx.array, func: Callable):
+        """Transforms masked indices of `x` conditioned on unmasked indices using `func`."""
+        x_masked = mx.where(self.mask, 0.0, x)
+        bijector_params = self.conditioner(x_masked)
+        y, log_det = func(bijector_params)
+        log_det = mx.where(self.mask, log_det, 0.0)
+        y = mx.where(self.mask, y, x)
+        return y, mx.sum(log_det, axis=-1)
+
+    def forward_and_log_det(self, x: mx.array):
+        """Transforms masked indices of `x` conditioned on unmasked indices using the bijector."""
+        return self.apply_mask(
+            x, lambda params: self.bijector(params).forward_and_log_det(x)
+        )
+
+    def inverse_and_log_det(self, y: mx.array):
+        """Transforms masked indices of `y` conditioned on unmasked indices using the bijector."""
+        return self.apply_mask(
+            y, lambda params: self.bijector(params).inverse_and_log_det(y)
+        )
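As a sanity check on the coupling logic in `bijectors.py`, here is a minimal sketch (assuming `bijectors.py` is importable from the working directory) that applies a `MaskedCoupling` forward and then inverts it: the input should be recovered and the two log-determinants should cancel.

```python
import mlx.core as mx
import mlx.nn as nn

from bijectors import AffineBijector, MaskedCoupling

# Alternating mask over 4 dimensions: transform the even indices,
# conditioning on the odd ones (True = transformed).
mask = mx.arange(4) % 2 == 0

# The conditioner maps the masked input to shift and log-scale
# parameters (2 * 4 values per sample).
conditioner = nn.Linear(4, 8)
coupling = MaskedCoupling(mask, conditioner, AffineBijector)

x = mx.random.normal(shape=(16, 4))
y, fwd_log_det = coupling.forward_and_log_det(x)
x_rec, inv_log_det = coupling.inverse_and_log_det(y)

# Invertibility: the input is recovered and the log-determinants cancel.
print(mx.allclose(x, x_rec, atol=1e-5))        # array(True, dtype=bool)
print(mx.allclose(fwd_log_det, -inv_log_det))  # array(True, dtype=bool)
```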
diff --git a/normalizing_flow/distributions.py b/normalizing_flow/distributions.py
new file mode 100644
index 00000000..2199a4a8
--- /dev/null
+++ b/normalizing_flow/distributions.py
@@ -0,0 +1,31 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Tuple, Optional, Union
+import math
+
+import mlx.core as mx
+
+
+class Normal:
+    def __init__(self, mu: mx.array, sigma: mx.array):
+        super().__init__()
+        self.mu = mu
+        self.sigma = sigma
+
+    def sample(
+        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
+    ):
+        return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu
+
+    def log_prob(self, x: mx.array):
+        return (
+            -0.5 * math.log(2 * math.pi)
+            - mx.log(self.sigma)
+            - 0.5 * ((x - self.mu) / self.sigma) ** 2
+        )
+
+    def sample_and_log_prob(
+        self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None
+    ):
+        x = self.sample(sample_shape, key=key)
+        return x, self.log_prob(x)
diff --git a/normalizing_flow/flows.py b/normalizing_flow/flows.py
new file mode 100644
index 00000000..0fbd550b
--- /dev/null
+++ b/normalizing_flow/flows.py
@@ -0,0 +1,76 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Tuple, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from bijectors import MaskedCoupling, AffineBijector
+from distributions import Normal
+
+
+class MLP(nn.Module):
+    def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
+        super().__init__()
+        layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
+        self.layers = [
+            nn.Linear(idim, odim)
+            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
+        ]
+
+    def __call__(self, x):
+        for l in self.layers[:-1]:
+            x = nn.gelu(l(x))
+        return self.layers[-1](x)
+
+
+class RealNVP(nn.Module):
+    def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_layers: int):
+        super().__init__()
+
+        # Alternating binary masks, flipped between successive transforms
+        self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)]
+        self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list]
+
+        self.freeze(keys=["mask_list"])
+
+        # Conditioning MLPs, one per transform
+        self.conditioner_list = [
+            MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)
+        ]
+
+        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
+
+    def log_prob(self, x: mx.array):
+        """
+        Flow back to the base Gaussian and compute its log-density,
+        accumulating the transformation log-determinants along the way.
+        """
+        log_prob = mx.zeros(x.shape[0])
+        for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
+            x, ldj = MaskedCoupling(
+                mask, conditioner, AffineBijector
+            ).inverse_and_log_det(x)
+            log_prob += ldj
+        return log_prob + self.base_dist.log_prob(x).sum(-1)
+
+    def sample(
+        self,
+        sample_shape: Union[int, Tuple[int, ...]],
+        key: Optional[mx.array] = None,
+        n_transforms: Optional[int] = None,
+    ):
+        """
+        Sample from the base Gaussian and flow toward the target distribution.
+        """
+        x = self.base_dist.sample(sample_shape, key=key)
+        for mask, conditioner in zip(
+            self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]
+        ):
+            x, _ = MaskedCoupling(
+                mask, conditioner, AffineBijector
+            ).forward_and_log_det(x)
+        return x
+
+    def __call__(self, x: mx.array):
+        return self.log_prob(x)
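The change-of-variables identity that `RealNVP.log_prob` implements, log p_X(x) = log p_Z(f⁻¹(x)) + log |det J_{f⁻¹}(x)|, can be verified by hand for a single transform. A minimal sketch, assuming `flows.py` and `bijectors.py` are importable:

```python
import mlx.core as mx

from bijectors import AffineBijector, MaskedCoupling
from flows import RealNVP

model = RealNVP(n_transforms=1, d_params=2, d_hidden=32, n_layers=2)
x = mx.random.normal(shape=(8, 2))

# Invert the single coupling transform by hand ...
coupling = MaskedCoupling(
    model.mask_list[0], model.conditioner_list[0], AffineBijector
)
z, log_det = coupling.inverse_and_log_det(x)

# ... and apply the change-of-variables formula directly.
manual_log_prob = model.base_dist.log_prob(z).sum(-1) + log_det

print(mx.allclose(manual_log_prob, model.log_prob(x)))  # array(True, dtype=bool)
```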
diff --git a/normalizing_flow/main.py b/normalizing_flow/main.py
new file mode 100644
index 00000000..0d3cfe56
--- /dev/null
+++ b/normalizing_flow/main.py
@@ -0,0 +1,115 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from tqdm import trange
+import numpy as np
+from sklearn import datasets, preprocessing
+import matplotlib.pyplot as plt
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+
+from flows import RealNVP
+
+
+def get_moons_dataset(n_samples=100_000, noise=0.06):
+    """Get the two moons dataset with the given noise level."""
+    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
+    scaler = preprocessing.StandardScaler()
+    x = scaler.fit_transform(x)
+    return x
+
+
+def main(args):
+    x = get_moons_dataset(n_samples=100_000, noise=args.noise)
+
+    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
+    mx.eval(model.parameters())
+
+    def loss_fn(model, x):
+        return -mx.mean(model(x))
+
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+    optimizer = optim.Adam(learning_rate=args.learning_rate)
+
+    with trange(args.n_steps) as steps:
+        for step in steps:
+            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
+            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
+
+            optimizer.update(model, grads)
+            mx.eval(model.parameters())
+
+            steps.set_postfix(loss=loss.item())
+
+    # Plot samples from the trained flow
+
+    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
+    cmap = plt.get_cmap("Blues")
+    bins = 100
+
+    # Sample from intermediate flow-transformed distributions
+    for n_transforms in range(args.n_transforms + 1):
+        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
+
+        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
+        axs[n_transforms].set_xlim(-2, 2)
+        axs[n_transforms].set_ylim(-2, 2)
+        axs[n_transforms].set_title(
+            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
+        )
+        axs[n_transforms].set_xticklabels([])
+        axs[n_transforms].set_yticklabels([])
+
+    # Plot the original data
+    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
+    axs[-1].set_xlim(-2, 2)
+    axs[-1].set_ylim(-2, 2)
+    axs[-1].set_title("Original data")
+    axs[-1].set_xticklabels([])
+    axs[-1].set_yticklabels([])
+
+    plt.tight_layout()
+    plt.savefig("samples.png")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--n_steps", type=int, default=5_000, help="Number of steps to train"
+    )
+    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
+    parser.add_argument(
+        "--n_transforms", type=int, default=6, help="Number of flow transforms"
+    )
+    parser.add_argument(
+        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
+    )
+    parser.add_argument(
+        "--d_hidden",
+        type=int,
+        default=128,
+        help="Hidden dimensionality of coupling conditioner",
+    )
+    parser.add_argument(
+        "--n_layers",
+        type=int,
+        default=4,
+        help="Number of layers in coupling conditioner",
+    )
+    parser.add_argument(
+        "--learning_rate", type=float, default=3e-4, help="Learning rate"
+    )
+    parser.add_argument(
+        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
+    )
+    parser.add_argument("--cpu", action="store_true")
+
+    args = parser.parse_args()
+
+    if args.cpu:
+        mx.set_default_device(mx.cpu)
+
+    main(args)
diff --git a/normalizing_flow/requirements.txt b/normalizing_flow/requirements.txt
new file mode 100644
index 00000000..5b335764
--- /dev/null
+++ b/normalizing_flow/requirements.txt
@@ -0,0 +1,5 @@
+mlx
+numpy
+tqdm
+scikit-learn
+matplotlib
\ No newline at end of file
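Beyond the sample histograms that `main.py` saves, the trained flow's `log_prob` can be evaluated on a grid to render the learned density directly. A sketch under the assumption that a trained `model` (as built in `main.py`) is in scope; the output name `density.png` is illustrative:

```python
import matplotlib.pyplot as plt
import numpy as np
import mlx.core as mx

# Evaluate the learned log-density on a 200x200 grid over the plot window.
ticks = np.linspace(-2, 2, 200, dtype=np.float32)
grid = np.stack(np.meshgrid(ticks, ticks), axis=-1).reshape(-1, 2)

log_prob = model.log_prob(mx.array(grid))
density = np.exp(np.array(log_prob)).reshape(200, 200)

# origin="lower" so the y-axis increases upward, matching hist2d.
plt.imshow(density, extent=(-2, 2, -2, 2), origin="lower", cmap="Blues")
plt.savefig("density.png")
```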
diff --git a/normalizing_flow/samples.png b/normalizing_flow/samples.png
new file mode 100644
index 00000000..9a26891b
Binary files /dev/null and b/normalizing_flow/samples.png differ