# Source: mlx-examples/normalizing_flow/main.py
# (116 lines, 3.4 KiB, Python)

# Copyright © 2023-2024 Apple Inc.
from tqdm import trange
import numpy as np
from sklearn import datasets, preprocessing
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from flows import RealNVP
def get_moons_dataset(n_samples=100_000, noise=0.06, random_state=None):
    """Get two moons dataset with given noise level.

    Args:
        n_samples: Number of points requested from ``datasets.make_moons``.
        noise: Noise level forwarded to ``datasets.make_moons``.
        random_state: Optional seed forwarded to ``datasets.make_moons`` so
            the generated points can be reproduced. The default ``None``
            leaves the generator unseeded, exactly as before.

    Returns:
        The generated points after being standardized with
        ``preprocessing.StandardScaler``. The class labels returned by
        ``make_moons`` are discarded.
    """
    x, _ = datasets.make_moons(
        n_samples=n_samples, noise=noise, random_state=random_state
    )
    # Fit the scaler on the full dataset and return the rescaled points.
    return preprocessing.StandardScaler().fit_transform(x)
def _plot_density(ax, points, title, bins, cmap):
    """Draw a 2D histogram of ``points`` on ``ax`` with the shared styling."""
    ax.hist2d(points[:, 0], points[:, 1], bins=bins, cmap=cmap)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_title(title)
    ax.set_xticklabels([])
    ax.set_yticklabels([])


def main(args):
    """Train a RealNVP model on the two moons data and save sample plots.

    Args:
        args: Parsed command-line namespace. Reads ``noise``,
            ``n_transforms``, ``d_params``, ``d_hidden``, ``n_layers``,
            ``learning_rate``, ``n_steps`` and ``n_batch``.

    Side effects:
        Writes the figure to ``samples.png`` in the current directory.
    """
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())

    def loss_fn(model, x):
        # Minimize the negated batch mean of the model output
        # (presumably a per-sample log-likelihood -- confirm in flows.py).
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for _ in steps:
            # Draw a mini-batch of distinct row indices.
            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))

            optimizer.update(model, grads)
            # Evaluate the optimizer state together with the parameters.
            mx.eval(model.parameters(), optimizer.state)

            # Report a plain Python scalar rather than the array object.
            steps.set_postfix(val=loss.item())

    # Plot samples from trained flow
    _, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
        title = (
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        _plot_density(axs[n_transforms], x_samples, title, bins, cmap)

    # Plot original data
    _plot_density(axs[-1], x, "Original data", bins, cmap)

    plt.tight_layout()
    plt.savefig("samples.png")
if __name__ == "__main__":
    import argparse

    # Command-line options; every value is consumed by main() except --cpu,
    # which is handled here before training starts.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_steps", type=int, default=5_000, help="Number of steps to train"
    )
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--n_transforms", type=int, default=6, help="Number of flow transforms"
    )
    parser.add_argument(
        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
    )
    parser.add_argument(
        "--d_hidden",
        type=int,
        default=128,
        help="Hidden dimensionality of coupling conditioner",
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of layers in coupling conditioner",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="Learning rate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
    )
    parser.add_argument(
        "--cpu", action="store_true", help="Use the CPU as the default device"
    )

    args = parser.parse_args()

    # Switch the default MLX device before any arrays are created in main().
    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)