mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Normalizing flow example (#133)
* Implement normalizing flow Real NVP example * Add requirements and basic usage to normalizing flow example * Minor changes to README in normalizing flow example * Remove trailing commas in function arguments for unified formatting in flows example * Fix minor typos, add some annotations * format + nits in README * readme fix * mov, minor changes in main, copywright * remove debug * fix * Simplified class structure in distributions; better code re-use in bijectors * Remove rogue space * change name again * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
committed by
GitHub
parent
cd3cff0858
commit
19b6167d81
115
normalizing_flow/main.py
Normal file
115
normalizing_flow/main.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
from sklearn import datasets, preprocessing
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
from flows import RealNVP
|
||||
|
||||
|
||||
def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Draw a standardized two-moons toy dataset.

    Args:
        n_samples: Number of points to sample.
        noise: Standard deviation of the Gaussian noise added to the moons.

    Returns:
        An (n_samples, 2) array scaled to zero mean and unit variance
        per feature.
    """
    samples, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    return preprocessing.StandardScaler().fit_transform(samples)
|
||||
|
||||
|
||||
def main(args):
    """Train a RealNVP flow on two moons, then plot samples.

    Trains with Adam on the negative mean log-density, then writes a
    figure to "samples.png" showing the base distribution, samples after
    each successive flow transform, and the original data.

    Args:
        args: Parsed command-line namespace (see the __main__ block).
    """
    data = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    # Force materialization of the lazily-initialized parameters.
    mx.eval(model.parameters())

    def loss_fn(model, x):
        # The model returns log-density; minimize its negative mean.
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as progress:
        for _ in progress:
            batch_idx = np.random.choice(data.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(data[batch_idx]))

            optimizer.update(model, grads)
            # Evaluate the updated parameters before the next step.
            mx.eval(model.parameters())

            progress.set_postfix(val=loss)

    # Plot samples from the trained flow: one panel per number of applied
    # transforms, plus a final panel with the original data.
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from the base distribution (0 transforms) up to the full flow.
    for n_transforms in range(args.n_transforms + 1):
        flow_samples = model.sample((100_000, 2), n_transforms=n_transforms)

        panel = axs[n_transforms]
        panel.hist2d(flow_samples[:, 0], flow_samples[:, 1], bins=bins, cmap=cmap)
        panel.set_xlim(-2, 2)
        panel.set_ylim(-2, 2)
        panel.set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        panel.set_xticklabels([])
        panel.set_yticklabels([])

    # Last panel: the training data itself, for visual comparison.
    data_panel = axs[-1]
    data_panel.hist2d(data[:, 0], data[:, 1], bins=bins, cmap=cmap)
    data_panel.set_xlim(-2, 2)
    data_panel.set_ylim(-2, 2)
    data_panel.set_title("Original data")
    data_panel.set_xticklabels([])
    data_panel.set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # (flag, type, default, help) for every numeric hyperparameter.
    arg_specs = [
        ("--n_steps", int, 5_000, "Number of steps to train"),
        ("--n_batch", int, 64, "Batch size"),
        ("--n_transforms", int, 6, "Number of flow transforms"),
        ("--d_params", int, 2, "Dimensionality of modeled distribution"),
        ("--d_hidden", int, 128, "Hidden dimensionality of coupling conditioner"),
        ("--n_layers", int, 4, "Number of layers in coupling conditioner"),
        ("--learning_rate", float, 3e-4, "Learning rate"),
        ("--noise", float, 0.06, "Noise level in two moons dataset"),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default, help_text in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)
    parser.add_argument("--cpu", action="store_true")

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
|
||||
Reference in New Issue
Block a user