Implement normalizing flow Real NVP example

This commit is contained in:
Siddharth Mishra-Sharma 2023-12-18 00:48:25 -05:00
parent 08e862336a
commit 5370d70122
6 changed files with 280 additions and 0 deletions

27
flow/README.md Normal file
View File

@ -0,0 +1,27 @@
# Normalizing flow
Real NVP normalizing flow from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803) implemented using `mlx`.
The example is written in a somewhat more object-oriented style than strictly necessary, so that it can be extended to other use cases that need arbitrary distributions and bijectors.
## Usage
The example can be run with
```
python main.py
```
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`.
By default the example runs on the GPU. To run on the CPU, do
```
python main.py --cpu
```
For all available options, run
```
python main.py --help
```
## Results
![Samples](./samples.png)

56
flow/bijectors.py Normal file
View File

@ -0,0 +1,56 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
class Bijector:
    """Interface for a transformation with a forward map, an inverse map,
    and a log-determinant term returned alongside each.

    Both methods only raise here; concrete bijectors override them.
    """

    def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """Map `x` forward and return `(y, log_det)`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
        """Map `y` backward and return `(x, log_det)`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class AffineBijector(Bijector):
    """Elementwise affine map `y = x * exp(log_scale) + shift`.

    `shift_and_log_scale` carries both parameters along its last axis; it is
    split into two equal parts (shift first, log-scale second) on each call.
    The log-determinant terms are returned per element, without reduction.
    """

    def __init__(self, shift_and_log_scale):
        self.shift_and_log_scale = shift_and_log_scale

    def _unpack(self):
        """Return `(shift, log_scale)` by halving the parameters on the last axis."""
        return mx.split(self.shift_and_log_scale, 2, axis=-1)

    def forward_and_log_det(self, x: mx.array):
        """Return `(x * exp(log_scale) + shift, log_scale)`."""
        shift, log_scale = self._unpack()
        return x * mx.exp(log_scale) + shift, log_scale

    def inverse_and_log_det(self, y: mx.array):
        """Return `((y - shift) * exp(-log_scale), -log_scale)`."""
        shift, log_scale = self._unpack()
        neg_log_scale = -log_scale
        return (y - shift) * mx.exp(neg_log_scale), neg_log_scale
class MaskedCoupling(Bijector):
    """Coupling layer with masking and conditioner.

    Entries where `mask` is true are transformed; the remaining entries pass
    through unchanged and are the only ones the conditioner sees (the masked
    entries are zeroed before the conditioner is called). The conditioner's
    output is handed to `bijector`, which is called to build the inner
    transformation.
    """

    def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
        self.mask = mask
        self.conditioner = conditioner
        self.bijector = bijector

    def _conditioned_bijector(self, inputs: mx.array):
        """Build the inner bijector from the unmasked entries of `inputs`."""
        visible = mx.where(self.mask, 0.0, inputs)
        return self.bijector(self.conditioner(visible))

    def forward_and_log_det(self, x: mx.array):
        """Transforms masked indices of `x` conditioned on unmasked indices using bijector.

        Returns the transformed array and the log-determinant summed over
        the last axis.
        """
        inner = self._conditioned_bijector(x)
        transformed, elementwise_log_det = inner.forward_and_log_det(x)
        # Only masked positions contribute; the rest are an identity map.
        masked_log_det = mx.where(self.mask, elementwise_log_det, 0.0)
        y = mx.where(self.mask, transformed, x)
        return y, mx.sum(masked_log_det, axis=-1)

    def inverse_and_log_det(self, y: mx.array):
        """Transforms masked indices of `y` conditioned on unmasked indices using bijector.

        Returns the inverted array and the log-determinant summed over the
        last axis.
        """
        inner = self._conditioned_bijector(y)
        inverted, elementwise_log_det = inner.inverse_and_log_det(y)
        # Only masked positions contribute; the rest are an identity map.
        masked_log_det = mx.where(self.mask, elementwise_log_det, 0.0)
        x = mx.where(self.mask, inverted, y)
        return x, mx.sum(masked_log_det, axis=-1)

38
flow/distributions.py Normal file
View File

@ -0,0 +1,38 @@
from typing import Tuple, Optional, Union
import math
import mlx.core as mx
class Distribution:
    """Interface for a distribution that can be sampled and scored.

    `sample`, `log_prob` and `sample_and_log_prob` only raise here; concrete
    distributions override them. Calling an instance draws a sample and
    returns its log-probability.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
    ) -> mx.array:
        """Draw a sample of shape `sample_shape`, optionally using `key`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def log_prob(self, x: mx.array) -> mx.array:
        """Return the log-probability of `x`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def sample_and_log_prob(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Draw a sample and return it together with its log-probability.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __call__(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
    ) -> mx.array:
        """Draw a sample via `sample` and return only its `log_prob`."""
        drawn = self.sample(sample_shape, key=key)
        return self.log_prob(drawn)
class Normal(Distribution):
    """Normal distribution with elementwise mean `mu` and standard deviation `sigma`.

    Densities are evaluated per element; no axis is reduced here, so callers
    that want a joint log-density sum the result themselves.
    """

    def __init__(self, mu: mx.array, sigma: mx.array):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
        """Draw standard-normal noise of shape `sample_shape`, then scale and shift it.

        Args:
            sample_shape: Full shape of the noise to draw. It is combined with
                `sigma` and `mu` by ordinary arithmetic, so it has to be
                compatible with their shapes. A bare int is treated as a
                one-dimensional shape.
            key: Optional PRNG key forwarded to `mx.random.normal`.
        """
        # The annotation allows an int, but the sampler is handed a shape
        # sequence; wrap a lone int so both spellings work.
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu

    def log_prob(self, x: mx.array):
        """Return the elementwise Gaussian log-density of `x`."""
        return -0.5 * math.log(2 * math.pi) - mx.log(self.sigma) - 0.5 * ((x - self.mu) / self.sigma) ** 2

    def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
        """Draw a sample and return it with its elementwise log-density."""
        x = self.sample(sample_shape, key=key)
        return x, self.log_prob(x)

68
flow/flows.py Normal file
View File

@ -0,0 +1,68 @@
from typing import Tuple, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from bijectors import MaskedCoupling, AffineBijector
from distributions import Normal
class MLP(nn.Module):
    """Stack of linear layers with GELU after every layer except the last.

    Args:
        n_layers: Number of hidden widths inserted between input and output,
            giving `n_layers + 1` linear layers in total.
        d_in: Input width.
        d_hidden: Width of every hidden layer.
        d_out: Output width.
    """

    def __init__(
        self,
        n_layers: int,
        d_in: int,
        d_hidden: int,
        d_out: int,
    ):
        super().__init__()
        widths = [d_in, *([d_hidden] * n_layers), d_out]
        # Pair each width with the next one to get (in, out) per layer.
        self.layers = [
            nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])
        ]

    def __call__(self, x):
        """Apply the hidden layers with GELU, then the final layer without activation."""
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = nn.gelu(layer(x))
        return output_layer(x)
class RealNVP(nn.Module):
    """Real NVP flow: a chain of masked affine coupling layers over a normal base.

    Args:
        n_transforms: Number of coupling layers.
        d_params: Size of the last axis of the modeled arrays.
        d_hidden: Hidden width of each conditioner MLP.
        n_layers: Number of hidden layers in each conditioner MLP.
    """

    def __init__(
        self,
        n_transforms: int,
        d_params: int,
        d_hidden: int,
        n_layers: int,
    ):
        super().__init__()

        # Alternating masks: layer i transforms the entries whose index has
        # the same parity as i and conditions on the others.
        self.mask_list = [
            (mx.arange(d_params) % 2 == i % 2).astype(mx.bool_)
            for i in range(n_transforms)
        ]
        # The masks are fixed structure, not something to optimize.
        self.freeze(keys=["mask_list"])

        # Conditioning MLP: one per layer, emitting shift and log-scale
        # (hence twice `d_params` outputs).
        self.conditioner_list = [
            MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)
        ]

        self.base_dist = Normal(mx.zeros(d_params), mx.ones(d_params))

    @property
    def bast_dist(self):
        """Misspelled former name of `base_dist`, kept so existing callers keep working."""
        return self.base_dist

    def log_prob(self, x: mx.array):
        """Return the log-density of `x`, one value per leading (batch) index.

        The coupling layers are inverted last-to-first, their
        log-determinants accumulated, and the base log-density (summed over
        the last axis) added at the end.
        """
        # One accumulator entry per batch element. Using every axis but the
        # last (instead of only the first) also covers unbatched input.
        log_prob = mx.zeros(x.shape[:-1])
        for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
            x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
            log_prob += ldj
        return log_prob + self.base_dist.log_prob(x).sum(-1)

    def sample(
        self,
        sample_shape: Union[int, Tuple[int, ...]],
        key: Optional[mx.array] = None,
        n_transforms: Optional[int] = None,
    ):
        """Draw from the base distribution and push the draw through the flow.

        Args:
            sample_shape: Shape handed to the base distribution's `sample`.
            key: Optional PRNG key for the base draw.
            n_transforms: If given, apply only the first `n_transforms`
                coupling layers; `None` applies all of them.
        """
        x = self.base_dist.sample(sample_shape, key=key)
        for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
            x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
        return x

    def __call__(self, x: mx.array):
        """Alias for `log_prob`, so the module can be called directly in a loss."""
        return self.log_prob(x)

91
flow/main.py Normal file
View File

@ -0,0 +1,91 @@
from tqdm import trange
import numpy as np
from sklearn import datasets, preprocessing
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from flows import RealNVP
def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Return standardized points from scikit-learn's two-moons generator.

    Args:
        n_samples: Number of points to generate.
        noise: Noise level passed to `make_moons`.

    Returns:
        The generated points after `StandardScaler.fit_transform`; the class
        labels are discarded.
    """
    points, _labels = datasets.make_moons(n_samples=n_samples, noise=noise)
    return preprocessing.StandardScaler().fit_transform(points)
def main(args):
    """Train a RealNVP flow on the two-moons data and save a figure of its samples.

    Args:
        args: Parsed command-line namespace; reads `noise`, `n_transforms`,
            `d_params`, `d_hidden`, `n_layers`, `learning_rate`, `n_steps`
            and `n_batch`.

    Writes the figure to `samples.png` in the current directory.
    """
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())

    def loss_fn(model, x):
        # Negative mean log-likelihood of the batch.
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for _ in steps:
            # Random mini-batch drawn without replacement.
            idx = np.random.permutation(x.shape[0])[: args.n_batch]
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)
            # Hand tqdm a plain Python number so the bar shows the value
            # rather than the array's string form.
            steps.set_postfix(val=loss.item())

    # Plot samples from trained flow
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    def _draw_panel(ax, points, title):
        # Every panel shares the same histogram settings, limits and bare ticks.
        ax.hist2d(points[:, 0], points[:, 1], bins=bins, cmap=cmap)
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_title(title)
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
        title = f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        _draw_panel(axs[n_transforms], x_samples, title)

    # Plot original data
    _draw_panel(axs[-1], x, "Original data")

    plt.tight_layout()
    plt.savefig("samples.png")
# Script entry point: parse the command-line options, pick the device, train.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_steps", type=int, default=5_000, help="Number of steps to train")
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument("--n_transforms", type=int, default=6, help="Number of flow transforms")
    parser.add_argument("--d_params", type=int, default=2, help="Dimensionality of modeled distribution")
    parser.add_argument("--d_hidden", type=int, default=128, help="Hidden dimensionality of coupling conditioner")
    parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in coupling conditioner")
    parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--noise", type=float, default=0.06, help="Noise level in two moons dataset")
    # The only flag that previously had no entry in `--help`.
    parser.add_argument("--cpu", action="store_true", help="Run on the CPU instead of the default device")
    args = parser.parse_args()

    # Switch the default device only when asked to; otherwise leave it alone.
    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)

BIN
flow/samples.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB