Normalizing flow example (#133)

* Implement Real NVP normalizing flow example

* Add requirements and basic usage to normalizing flow example

* Minor changes to README in normalizing flow example

* Remove trailing commas in function arguments for unified formatting in flows example

* Fix minor typos, add some annotations

* Format and nits in README

* README fix

* Move files, minor changes in main, add copyright

* Remove debug code

* Fix

* Simplified class structure in distributions; better code re-use in bijectors

* Remove rogue space

* Change name again

* Nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Siddharth Mishra-Sharma
Date: 2024-01-13 19:58:48 -05:00
Committed by: GitHub
Commit: 19b6167d81 (parent: cd3cff0858)
7 changed files with 339 additions and 0 deletions

normalizing_flow/main.py (new file, 115 lines)

@@ -0,0 +1,115 @@
# Copyright © 2023-2024 Apple Inc.

from tqdm import trange

import numpy as np
from sklearn import datasets, preprocessing
import matplotlib.pyplot as plt

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from flows import RealNVP


def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Get two moons dataset with given noise level."""
    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    scaler = preprocessing.StandardScaler()
    x = scaler.fit_transform(x)
    return x


def main(args):
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())
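
    # The flow's __call__ returns per-sample log-densities (change of
    # variables: log p(x) = log p_base(z) + log|det dz/dx|), so training
    # minimizes the negative mean log-likelihood of the data.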
    def loss_fn(model, x):
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for step in steps:
            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
            optimizer.update(model, grads)
            mx.eval(model.parameters())
            steps.set_postfix(val=loss)

    # Plot samples from trained flow
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
        axs[n_transforms].set_xlim(-2, 2)
        axs[n_transforms].set_ylim(-2, 2)
        axs[n_transforms].set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        axs[n_transforms].set_xticklabels([])
        axs[n_transforms].set_yticklabels([])

    # Plot original data
    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
    axs[-1].set_xlim(-2, 2)
    axs[-1].set_ylim(-2, 2)
    axs[-1].set_title("Original data")
    axs[-1].set_xticklabels([])
    axs[-1].set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_steps", type=int, default=5_000, help="Number of steps to train"
    )
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--n_transforms", type=int, default=6, help="Number of flow transforms"
    )
    parser.add_argument(
        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
    )
    parser.add_argument(
        "--d_hidden",
        type=int,
        default=128,
        help="Hidden dimensionality of coupling conditioner",
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of layers in coupling conditioner",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="Learning rate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
    )
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
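
Note: flows.py, one of the seven changed files but not shown in this excerpt, defines the RealNVP model imported above; per the commit messages it is structured around distributions and bijectors modules. As a rough sketch of the interface main.py relies on — RealNVP(n_transforms, d_params, d_hidden, n_layers), a __call__ returning per-sample log-densities, and sample(shape, n_transforms=...) — a minimal mask-based affine-coupling version could look like the following. AffineCoupling and all internal details here are illustrative assumptions, not the example's actual code.

# A hypothetical, minimal flows.py satisfying main.py's interface.
import math

import mlx.core as mx
import mlx.nn as nn


class AffineCoupling(nn.Module):
    """One Real NVP coupling layer: masked coordinates pass through
    unchanged and condition an affine (scale + shift) map of the rest."""

    def __init__(self, d_params, d_hidden, n_layers, mask):
        super().__init__()
        self.mask = mask
        layers = []
        d_in = d_params
        for _ in range(n_layers):
            layers.extend([nn.Linear(d_in, d_hidden), nn.ReLU()])
            d_in = d_hidden
        layers.append(nn.Linear(d_in, 2 * d_params))
        self.conditioner = nn.Sequential(*layers)

    def _scale_and_shift(self, masked):
        log_s, t = mx.split(self.conditioner(masked), 2, axis=-1)
        # Bound the log-scale for stability; zero it on masked dims
        log_s = mx.tanh(log_s) * (1 - self.mask)
        return log_s, t * (1 - self.mask)

    def __call__(self, x):
        """Base -> data direction; returns (y, log|det dy/dx|)."""
        log_s, t = self._scale_and_shift(x * self.mask)
        return x * mx.exp(log_s) + t, mx.sum(log_s, axis=-1)

    def inverse(self, y):
        """Data -> base direction; returns (x, log|det dx/dy|)."""
        log_s, t = self._scale_and_shift(y * self.mask)
        return (y - t) * mx.exp(-log_s), -mx.sum(log_s, axis=-1)


class RealNVP(nn.Module):
    def __init__(self, n_transforms, d_params, d_hidden, n_layers):
        super().__init__()
        # Alternate the mask so every coordinate gets transformed
        self.transforms = [
            AffineCoupling(
                d_params,
                d_hidden,
                n_layers,
                mx.array([(i + t) % 2 for i in range(d_params)], dtype=mx.float32),
            )
            for t in range(n_transforms)
        ]

    def __call__(self, x):
        """Exact log-density of x: invert the flow and add the base term."""
        log_prob = mx.zeros(x.shape[:-1])
        for transform in reversed(self.transforms):
            x, log_det = transform.inverse(x)
            log_prob = log_prob + log_det
        # Standard normal base distribution
        return log_prob + mx.sum(-0.5 * (x**2 + math.log(2 * math.pi)), axis=-1)

    def sample(self, sample_shape, n_transforms=None):
        """Push base samples through the first n_transforms layers."""
        n = len(self.transforms) if n_transforms is None else n_transforms
        x = mx.random.normal(sample_shape)
        for transform in self.transforms[:n]:
            x, _ = transform(x)
        return x

Under these assumptions, running e.g. python main.py --n_steps 5000 --cpu (flags as defined above) trains the flow on two moons and writes samples.png.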