diff --git a/flow/README.md b/flow/README.md new file mode 100644 index 00000000..dc07b8e9 --- /dev/null +++ b/flow/README.md @@ -0,0 +1,27 @@ +# Normalizing flow + +Real NVP normalizing flow from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803) implemented using `mlx`. + +The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases benefitting from arbitrary distributions and bijectors. + +## Usage + +The example can be run with +``` +python main.py +``` +which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`. + +By default the example runs on the GPU. To run on the CPU, do +``` +python main.py --cpu +``` + +For all available options, run +``` +python main.py --help +``` + +## Results + +![Samples](./samples.png) diff --git a/flow/bijectors.py b/flow/bijectors.py new file mode 100644 index 00000000..c02f0ba8 --- /dev/null +++ b/flow/bijectors.py @@ -0,0 +1,56 @@ +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn + + +class Bijector: + def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]: + raise NotImplementedError + + def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]: + raise NotImplementedError + + +class AffineBijector(Bijector): + def __init__(self, shift_and_log_scale): + self.shift_and_log_scale = shift_and_log_scale + + def forward_and_log_det(self, x: mx.array): + shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1) + y = x * mx.exp(log_scale) + shift + log_det = log_scale + return y, log_det + + def inverse_and_log_det(self, y: mx.array): + shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1) + x = (y - shift) * mx.exp(-log_scale) + log_det = -log_scale + + return x, log_det + + +class MaskedCoupling(Bijector): + def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector): + """Coupling layer with masking and conditioner.""" + self.mask = mask + self.conditioner = conditioner + self.bijector = bijector + + def forward_and_log_det(self, x: mx.array): + """Transforms masked indices of `x` conditioned on unmasked indices using bijector.""" + x_cond = mx.where(self.mask, 0.0, x) + bijector_params = self.conditioner(x_cond) + y, log_det = self.bijector(bijector_params).forward_and_log_det(x) + log_det = mx.where(self.mask, log_det, 0.0) + y = mx.where(self.mask, y, x) + return y, mx.sum(log_det, axis=-1) + + def inverse_and_log_det(self, y: mx.array): + """Transforms masked indices of `y` conditioned on unmasked indices using bijector.""" + y_cond = mx.where(self.mask, 0.0, y) + bijector_params = self.conditioner(y_cond) + x, log_det = self.bijector(bijector_params).inverse_and_log_det(y) + log_det = mx.where(self.mask, log_det, 0.0) + x = mx.where(self.mask, x, y) + return x, mx.sum(log_det, axis=-1) diff --git a/flow/distributions.py b/flow/distributions.py new file mode 100644 index 00000000..168d81c5 --- /dev/null +++ b/flow/distributions.py @@ -0,0 +1,38 @@ +from typing import Tuple, Optional, Union +import math + +import mlx.core as mx + + +class Distribution: + def __init__(self): + pass + + def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array: + raise NotImplementedError + + def log_prob(self, x: mx.array) -> mx.array: + raise NotImplementedError + + def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> Tuple[mx.array, mx.array]: + raise NotImplementedError + + def __call__(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array: + return self.log_prob(self.sample(sample_shape, key=key)) + + +class Normal(Distribution): + def __init__(self, mu: mx.array, sigma: mx.array): + super().__init__() + self.mu = mu + self.sigma = sigma + + def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None): + return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu + + def log_prob(self, x: mx.array): + return -0.5 * math.log(2 * math.pi) - mx.log(self.sigma) - 0.5 * ((x - self.mu) / self.sigma) ** 2 + + def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None): + x = self.sample(sample_shape, key=key) + return x, self.log_prob(x) diff --git a/flow/flows.py b/flow/flows.py new file mode 100644 index 00000000..59aed19e --- /dev/null +++ b/flow/flows.py @@ -0,0 +1,68 @@ +from typing import Tuple, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from bijectors import MaskedCoupling, AffineBijector +from distributions import Normal + + +class MLP(nn.Module): + def __init__( + self, + n_layers: int, + d_in: int, + d_hidden: int, + d_out: int, + ): + super().__init__() + layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out] + self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])] + + def __call__(self, x): + for l in self.layers[:-1]: + x = nn.gelu(l(x)) + return self.layers[-1](x) + + +class RealNVP(nn.Module): + def __init__( + self, + n_transforms: int, + d_params: int, + d_hidden: int, + n_layers: int, + ): + super().__init__() + + # Alternating masks + self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)] + self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list] + + self.freeze(keys=["mask_list"]) + + # Conditioning MLP + self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)] + + self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params)) + + def log_prob(self, x: mx.array): + log_prob = mx.zeros(x.shape[0]) + for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]): + x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x) + log_prob += ldj + return log_prob + self.bast_dist.log_prob(x).sum(-1) + + def sample( + self, + sample_shape: Union[int, Tuple[int, ...]], + key: Optional[mx.array] = None, + n_transforms: Optional[int] = None, + ): + x = self.bast_dist.sample(sample_shape, key=key) + for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]): + x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x) + return x + + def __call__(self, x: mx.array): + return self.log_prob(x) diff --git a/flow/main.py b/flow/main.py new file mode 100644 index 00000000..d14dd8aa --- /dev/null +++ b/flow/main.py @@ -0,0 +1,91 @@ +from tqdm import trange +import numpy as np +from sklearn import datasets, preprocessing +import matplotlib.pyplot as plt + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim + +from flows import RealNVP + + +def get_moons_dataset(n_samples=100_000, noise=0.06): + """Get two moons dataset with given noise level.""" + x, _ = datasets.make_moons(n_samples=n_samples, noise=noise) + scaler = preprocessing.StandardScaler() + x = scaler.fit_transform(x) + return x + + +def main(args): + x = get_moons_dataset(n_samples=100_000, noise=args.noise) + + model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers) + mx.eval(model.parameters()) + + def loss_fn(model, x): + return -mx.mean(model(x)) + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + optimizer = optim.Adam(learning_rate=args.learning_rate) + + with trange(args.n_steps) as steps: + for step in steps: + idx = np.random.permutation(x.shape[0])[: args.n_batch] + loss, grads = loss_and_grad_fn(model, mx.array(x[idx])) + + optimizer.update(model, grads) + mx.eval(model.parameters(), optimizer.state) + + steps.set_postfix(val=loss) + + # Plot samples from trained flow + + fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4)) + cmap = plt.get_cmap("Blues") + bins = 100 + + # Sample from intermediate flow-transformed distributions + for n_transforms in range(args.n_transforms + 1): + x_samples = model.sample((100_000, 2), n_transforms=n_transforms) + + axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap) + axs[n_transforms].set_xlim(-2, 2) + axs[n_transforms].set_ylim(-2, 2) + axs[n_transforms].set_title(f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution") + axs[n_transforms].set_xticklabels([]) + axs[n_transforms].set_yticklabels([]) + + # Plot original data + axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap) + axs[-1].set_xlim(-2, 2) + axs[-1].set_ylim(-2, 2) + axs[-1].set_title("Original data") + axs[-1].set_xticklabels([]) + axs[-1].set_yticklabels([]) + + plt.tight_layout() + plt.savefig("samples.png") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--n_steps", type=int, default=5_000, help="Number of steps to train") + parser.add_argument("--n_batch", type=int, default=64, help="Batch size") + parser.add_argument("--n_transforms", type=int, default=6, help="Number of flow transforms") + parser.add_argument("--d_params", type=int, default=2, help="Dimensionality of modeled distribution") + parser.add_argument("--d_hidden", type=int, default=128, help="Hidden dimensionality of coupling conditioner") + parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in coupling conditioner") + parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--noise", type=float, default=0.06, help="Noise level in two moons dataset") + parser.add_argument("--cpu", action="store_true") + + args = parser.parse_args() + + if args.cpu: + mx.set_default_device(mx.cpu) + + main(args) diff --git a/flow/samples.png b/flow/samples.png new file mode 100644 index 00000000..e8bc622c Binary files /dev/null and b/flow/samples.png differ