mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Implement normalizing flow Real NVP example
This commit is contained in:
parent
08e862336a
commit
5370d70122
27
flow/README.md
Normal file
27
flow/README.md
Normal file
@ -0,0 +1,27 @@
|
||||
# Normalizing flow
|
||||
|
||||
Real NVP normalizing flow from [Dinh et al. (2016)](https://arxiv.org/abs/1605.08803) implemented using `mlx`.
|
||||
|
||||
The example is written in a somewhat more object-oriented style than strictly necessary, with an eye towards extension to other use cases benefitting from arbitrary distributions and bijectors.
|
||||
|
||||
## Usage
|
||||
|
||||
The example can be run with
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
which trains the normalizing flow on the two moons dataset and plots the result in `samples.png`.
|
||||
|
||||
By default the example runs on the GPU. To run on the CPU, do
|
||||
```
|
||||
python main.py --cpu
|
||||
```
|
||||
|
||||
For all available options, run
|
||||
```
|
||||
python main.py --help
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||

|
56
flow/bijectors.py
Normal file
56
flow/bijectors.py
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class Bijector:
|
||||
def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
raise NotImplementedError
|
||||
|
||||
def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AffineBijector(Bijector):
|
||||
def __init__(self, shift_and_log_scale):
|
||||
self.shift_and_log_scale = shift_and_log_scale
|
||||
|
||||
def forward_and_log_det(self, x: mx.array):
|
||||
shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
|
||||
y = x * mx.exp(log_scale) + shift
|
||||
log_det = log_scale
|
||||
return y, log_det
|
||||
|
||||
def inverse_and_log_det(self, y: mx.array):
|
||||
shift, log_scale = mx.split(self.shift_and_log_scale, 2, axis=-1)
|
||||
x = (y - shift) * mx.exp(-log_scale)
|
||||
log_det = -log_scale
|
||||
|
||||
return x, log_det
|
||||
|
||||
|
||||
class MaskedCoupling(Bijector):
|
||||
def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: Bijector):
|
||||
"""Coupling layer with masking and conditioner."""
|
||||
self.mask = mask
|
||||
self.conditioner = conditioner
|
||||
self.bijector = bijector
|
||||
|
||||
def forward_and_log_det(self, x: mx.array):
|
||||
"""Transforms masked indices of `x` conditioned on unmasked indices using bijector."""
|
||||
x_cond = mx.where(self.mask, 0.0, x)
|
||||
bijector_params = self.conditioner(x_cond)
|
||||
y, log_det = self.bijector(bijector_params).forward_and_log_det(x)
|
||||
log_det = mx.where(self.mask, log_det, 0.0)
|
||||
y = mx.where(self.mask, y, x)
|
||||
return y, mx.sum(log_det, axis=-1)
|
||||
|
||||
def inverse_and_log_det(self, y: mx.array):
|
||||
"""Transforms masked indices of `y` conditioned on unmasked indices using bijector."""
|
||||
y_cond = mx.where(self.mask, 0.0, y)
|
||||
bijector_params = self.conditioner(y_cond)
|
||||
x, log_det = self.bijector(bijector_params).inverse_and_log_det(y)
|
||||
log_det = mx.where(self.mask, log_det, 0.0)
|
||||
x = mx.where(self.mask, x, y)
|
||||
return x, mx.sum(log_det, axis=-1)
|
38
flow/distributions.py
Normal file
38
flow/distributions.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Tuple, Optional, Union
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class Distribution:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array:
|
||||
raise NotImplementedError
|
||||
|
||||
def log_prob(self, x: mx.array) -> mx.array:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> Tuple[mx.array, mx.array]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None) -> mx.array:
|
||||
return self.log_prob(self.sample(sample_shape, key=key))
|
||||
|
||||
|
||||
class Normal(Distribution):
|
||||
def __init__(self, mu: mx.array, sigma: mx.array):
|
||||
super().__init__()
|
||||
self.mu = mu
|
||||
self.sigma = sigma
|
||||
|
||||
def sample(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
|
||||
return mx.random.normal(sample_shape, key=key) * self.sigma + self.mu
|
||||
|
||||
def log_prob(self, x: mx.array):
|
||||
return -0.5 * math.log(2 * math.pi) - mx.log(self.sigma) - 0.5 * ((x - self.mu) / self.sigma) ** 2
|
||||
|
||||
def sample_and_log_prob(self, sample_shape: Union[int, Tuple[int, ...]], key: Optional[mx.array] = None):
|
||||
x = self.sample(sample_shape, key=key)
|
||||
return x, self.log_prob(x)
|
68
flow/flows.py
Normal file
68
flow/flows.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import Tuple, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from bijectors import MaskedCoupling, AffineBijector
|
||||
from distributions import Normal
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
d_in: int,
|
||||
d_hidden: int,
|
||||
d_out: int,
|
||||
):
|
||||
super().__init__()
|
||||
layer_sizes = [d_in] + [d_hidden] * n_layers + [d_out]
|
||||
self.layers = [nn.Linear(idim, odim) for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])]
|
||||
|
||||
def __call__(self, x):
|
||||
for l in self.layers[:-1]:
|
||||
x = nn.gelu(l(x))
|
||||
return self.layers[-1](x)
|
||||
|
||||
|
||||
class RealNVP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_transforms: int,
|
||||
d_params: int,
|
||||
d_hidden: int,
|
||||
n_layers: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Alternating masks
|
||||
self.mask_list = [mx.arange(d_params) % 2 == i % 2 for i in range(n_transforms)]
|
||||
self.mask_list = [mask.astype(mx.bool_) for mask in self.mask_list]
|
||||
|
||||
self.freeze(keys=["mask_list"])
|
||||
|
||||
# Conditioning MLP
|
||||
self.conditioner_list = [MLP(n_layers, d_params, d_hidden, 2 * d_params) for _ in range(n_transforms)]
|
||||
|
||||
self.bast_dist = Normal(mx.zeros(d_params), mx.ones(d_params))
|
||||
|
||||
def log_prob(self, x: mx.array):
|
||||
log_prob = mx.zeros(x.shape[0])
|
||||
for mask, conditioner in zip(self.mask_list[::-1], self.conditioner_list[::-1]):
|
||||
x, ldj = MaskedCoupling(mask, conditioner, AffineBijector).inverse_and_log_det(x)
|
||||
log_prob += ldj
|
||||
return log_prob + self.bast_dist.log_prob(x).sum(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
sample_shape: Union[int, Tuple[int, ...]],
|
||||
key: Optional[mx.array] = None,
|
||||
n_transforms: Optional[int] = None,
|
||||
):
|
||||
x = self.bast_dist.sample(sample_shape, key=key)
|
||||
for mask, conditioner in zip(self.mask_list[:n_transforms], self.conditioner_list[:n_transforms]):
|
||||
x, _ = MaskedCoupling(mask, conditioner, AffineBijector).forward_and_log_det(x)
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return self.log_prob(x)
|
91
flow/main.py
Normal file
91
flow/main.py
Normal file
@ -0,0 +1,91 @@
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
from sklearn import datasets, preprocessing
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
from flows import RealNVP
|
||||
|
||||
|
||||
def get_moons_dataset(n_samples=100_000, noise=0.06):
|
||||
"""Get two moons dataset with given noise level."""
|
||||
x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
|
||||
scaler = preprocessing.StandardScaler()
|
||||
x = scaler.fit_transform(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args):
|
||||
x = get_moons_dataset(n_samples=100_000, noise=args.noise)
|
||||
|
||||
model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
def loss_fn(model, x):
|
||||
return -mx.mean(model(x))
|
||||
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
optimizer = optim.Adam(learning_rate=args.learning_rate)
|
||||
|
||||
with trange(args.n_steps) as steps:
|
||||
for step in steps:
|
||||
idx = np.random.permutation(x.shape[0])[: args.n_batch]
|
||||
loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
|
||||
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
steps.set_postfix(val=loss)
|
||||
|
||||
# Plot samples from trained flow
|
||||
|
||||
fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
|
||||
cmap = plt.get_cmap("Blues")
|
||||
bins = 100
|
||||
|
||||
# Sample from intermediate flow-transformed distributions
|
||||
for n_transforms in range(args.n_transforms + 1):
|
||||
x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
|
||||
|
||||
axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
|
||||
axs[n_transforms].set_xlim(-2, 2)
|
||||
axs[n_transforms].set_ylim(-2, 2)
|
||||
axs[n_transforms].set_title(f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution")
|
||||
axs[n_transforms].set_xticklabels([])
|
||||
axs[n_transforms].set_yticklabels([])
|
||||
|
||||
# Plot original data
|
||||
axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
|
||||
axs[-1].set_xlim(-2, 2)
|
||||
axs[-1].set_ylim(-2, 2)
|
||||
axs[-1].set_title("Original data")
|
||||
axs[-1].set_xticklabels([])
|
||||
axs[-1].set_yticklabels([])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig("samples.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--n_steps", type=int, default=5_000, help="Number of steps to train")
|
||||
parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
|
||||
parser.add_argument("--n_transforms", type=int, default=6, help="Number of flow transforms")
|
||||
parser.add_argument("--d_params", type=int, default=2, help="Dimensionality of modeled distribution")
|
||||
parser.add_argument("--d_hidden", type=int, default=128, help="Hidden dimensionality of coupling conditioner")
|
||||
parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in coupling conditioner")
|
||||
parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
|
||||
parser.add_argument("--noise", type=float, default=0.06, help="Noise level in two moons dataset")
|
||||
parser.add_argument("--cpu", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
main(args)
|
BIN
flow/samples.png
Normal file
BIN
flow/samples.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 82 KiB |
Loading…
Reference in New Issue
Block a user