mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Implement normalizing flow Real NVP example
This commit is contained in:
parent
08e862336a
commit
5370d70122
27
flow/README.md
Normal file
27
flow/README.md
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Normalizing flow
|
||||||
|
|
||||||
|
Real NVP normalizing flow from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803) implemented using `mlx`.
|
||||||
|
|
||||||
|
The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases benefitting from arbitrary distributions and bijectors.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
The example can be run with
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`.
|
||||||
|
|
||||||
|
By default the example runs on the GPU. To run on the CPU, do
|
||||||
|
```
|
||||||
|
python main.py --cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
For all available options, run
|
||||||
|
```
|
||||||
|
python main.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|

|
56
flow/bijectors.py
Normal file
56
flow/bijectors.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Bijector:
|
||||||
|
def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AffineBijector(Bijector):
|
||||||
|
def __init__(self, shift_and_log_scale):
|
||||||
|
self.shift_and_log_scale = shift_and_log_scale
|
||||||
|
|
||||||
|
def forward_and_log_det(self, x: mx.array):
|
||||||
|
shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
|
||||||
|
y = x * mx.exp(log_scale) + shift
|
||||||
|
log_det = log_scale
|
||||||
|
return y, log_det
|
||||||
|
|
||||||
|
def inverse_and_log_det(self, y: mx.array):
|
||||||
|
shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
|
||||||
|
x = (y - shift) * mx.exp(-log_scale)
|
||||||
|
log_det = -log_scale
|
||||||
|
|
||||||
|
return x, log_det
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedCoupling(Bijector):
|
||||||
|
def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
|
||||||
|
"""Coupling layer with masking and conditioner."""
|
||||||
|
self.mask = mask
|
||||||
|
self.conditioner = conditioner
|
||||||
|
self.bijector = bijector
|
||||||
|
|
||||||
|
def forward_and_log_det(self, x: mx.array):
|
||||||
|
"""Transforms masked indices of `x` conditioned on unmasked indices using bijector."""
|
||||||
|
x_cond = mx.where(self.mask, 0.0, x)
|
||||||
|
bijector_params = self.conditioner(x_cond)
|
||||||
|
y, log_det = self.bijector(bijector_params).forward_and_log_det(x)
|
||||||
|
log_det = mx.where(self.mask, log_det, 0.0)
|
||||||
|
y = mx.where(self.mask, y, x)
|
||||||
|
return y, mx.sum(log_det, axis=-1)
|
||||||
|
|
||||||
|
def inverse_and_log_det(self, y: mx.array):
|
||||||
|
"""Transforms masked indices of `y` conditioned on unmasked indices using bijector."""
|
||||||
|
y_cond = mx.where(self.mask, 0.0, y)
|
||||||
|
bijector_params = self.conditioner(y_cond)
|
||||||
|
x, log_det = self.bijector(bijector_params).inverse_and_log_det(y)
|
||||||
|
log_det = mx.where(self.mask, log_det, 0.0)
|
||||||
|
x = mx.where(self.mask, x, y)
|
||||||
|
return x, mx.sum(log_det, axis=-1)
|
38
flow/distributions.py
Normal file
38
flow/distributions.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import Tuple, Optional, Union
|
||||||
|
import math
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
class Distribution:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def log_prob(self, x: mx.array) -> mx.array:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> Tuple[mx.array, mx.array]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array:
|
||||||
|
return self.log_prob(self.sample(sample_shape, key=key))
|
||||||
|
|
||||||
|
|
||||||
|
class Normal(Distribution):
|
||||||
|
def __init__(self, mu: mx.array, sigma: mx.array):
|
||||||
|
super().__init__()
|
||||||
|
self.mu = mu
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
|
||||||
|
return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu
|
||||||
|
|
||||||
|
def log_prob(self, x: mx.array):
|
||||||
|
return -0.5 * math.log(2 * math.pi) - mx.log(self.sigma) - 0.5 * ((x - self.mu) / self.sigma) ** 2
|
||||||
|
|
||||||
|
def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
|
||||||
|
x = self.sample(sample_shape, key=key)
|
||||||
|
return x, self.log_prob(x)
|
68
flow/flows.py
Normal file
68
flow/flows.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from typing import Tuple, Optional, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from bijectors import MaskedCoupling, AffineBijector
|
||||||
|
from distributions import Normal
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_layers: int,
|
||||||
|
d_in: int,
|
||||||
|
d_hidden: int,
|
||||||
|
d_out: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
|
||||||
|
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
for l in self.layers[:-1]:
|
||||||
|
x = nn.gelu(l(x))
|
||||||
|
return self.layers[-1](x)
|
||||||
|
|
||||||
|
|
||||||
|
class RealNVP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_transforms: int,
|
||||||
|
d_params: int,
|
||||||
|
d_hidden: int,
|
||||||
|
n_layers: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Alternating masks
|
||||||
|
self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)]
|
||||||
|
self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list]
|
||||||
|
|
||||||
|
self.freeze(keys=["mask_list"])
|
||||||
|
|
||||||
|
# Conditioning MLP
|
||||||
|
self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)]
|
||||||
|
|
||||||
|
self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
|
||||||
|
|
||||||
|
def log_prob(self, x: mx.array):
|
||||||
|
log_prob = mx.zeros(x.shape[0])
|
||||||
|
for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
|
||||||
|
x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
|
||||||
|
log_prob += ldj
|
||||||
|
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
sample_shape: Union[int, Tuple[int, ...]],
|
||||||
|
key: Optional[mx.array] = None,
|
||||||
|
n_transforms: Optional[int] = None,
|
||||||
|
):
|
||||||
|
x = self.bast_dist.sample(sample_shape, key=key)
|
||||||
|
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
||||||
|
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array):
|
||||||
|
return self.log_prob(x)
|
91
flow/main.py
Normal file
91
flow/main.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from tqdm import trange
|
||||||
|
import numpy as np
|
||||||
|
from sklearn import datasets, preprocessing
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx.optimizers as optim
|
||||||
|
|
||||||
|
from flows import RealNVP
|
||||||
|
|
||||||
|
|
||||||
|
def get_moons_dataset(n_samples=100_000, noise=0.06):
|
||||||
|
"""Get two moons dataset with given noise level."""
|
||||||
|
x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
|
||||||
|
scaler = preprocessing.StandardScaler()
|
||||||
|
x = scaler.fit_transform(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
x = get_moons_dataset(n_samples=100_000, noise=args.noise)
|
||||||
|
|
||||||
|
model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
|
||||||
|
def loss_fn(model, x):
|
||||||
|
return -mx.mean(model(x))
|
||||||
|
|
||||||
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
|
optimizer = optim.Adam(learning_rate=args.learning_rate)
|
||||||
|
|
||||||
|
with trange(args.n_steps) as steps:
|
||||||
|
for step in steps:
|
||||||
|
idx = np.random.permutation(x.shape[0])[: args.n_batch]
|
||||||
|
loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
|
||||||
|
|
||||||
|
optimizer.update(model, grads)
|
||||||
|
mx.eval(model.parameters(), optimizer.state)
|
||||||
|
|
||||||
|
steps.set_postfix(val=loss)
|
||||||
|
|
||||||
|
# Plot samples from trained flow
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
|
||||||
|
cmap = plt.get_cmap("Blues")
|
||||||
|
bins = 100
|
||||||
|
|
||||||
|
# Sample from intermediate flow-transformed distributions
|
||||||
|
for n_transforms in range(args.n_transforms + 1):
|
||||||
|
x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
|
||||||
|
|
||||||
|
axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
|
||||||
|
axs[n_transforms].set_xlim(-2, 2)
|
||||||
|
axs[n_transforms].set_ylim(-2, 2)
|
||||||
|
axs[n_transforms].set_title(f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution")
|
||||||
|
axs[n_transforms].set_xticklabels([])
|
||||||
|
axs[n_transforms].set_yticklabels([])
|
||||||
|
|
||||||
|
# Plot original data
|
||||||
|
axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
|
||||||
|
axs[-1].set_xlim(-2, 2)
|
||||||
|
axs[-1].set_ylim(-2, 2)
|
||||||
|
axs[-1].set_title("Original data")
|
||||||
|
axs[-1].set_xticklabels([])
|
||||||
|
axs[-1].set_yticklabels([])
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("samples.png")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--n_steps", type=int, default=5_000, help="Number of steps to train")
|
||||||
|
parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
|
||||||
|
parser.add_argument("--n_transforms", type=int, default=6, help="Number of flow transforms")
|
||||||
|
parser.add_argument("--d_params", type=int, default=2, help="Dimensionality of modeled distribution")
|
||||||
|
parser.add_argument("--d_hidden", type=int, default=128, help="Hidden dimensionality of coupling conditioner")
|
||||||
|
parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in coupling conditioner")
|
||||||
|
parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
|
||||||
|
parser.add_argument("--noise", type=float, default=0.06, help="Noise level in two moons dataset")
|
||||||
|
parser.add_argument("--cpu", action="store_true")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.cpu:
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
main(args)
|
BIN
flow/samples.png
Normal file
BIN
flow/samples.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 82 KiB |
Loading…
Reference in New Issue
Block a user