# Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30)

from tqdm import trange

import numpy as np
from sklearn import datasets, preprocessing
import matplotlib.pyplot as plt

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from flows import RealNVP
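
# Note: this script assumes the local `flows` module provides a RealNVP class
# whose __call__ returns the log-density of each input under the flow and whose
# sample(shape, n_transforms=...) method draws samples, optionally passing them
# through only the first `n_transforms` coupling layers (used for the plots below).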


def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Get two moons dataset with given noise level."""
    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    scaler = preprocessing.StandardScaler()
    x = scaler.fit_transform(x)
    return x


def main(args):
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
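    # MLX evaluates lazily; forcing evaluation here materializes the freshly
    # initialized parameters before training starts.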
    mx.eval(model.parameters())
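
    # The flow is trained by maximum likelihood: model(x) is taken to return
    # log p_X(x) = log p_Z(f(x)) + log|det J_f(x)| (change of variables), so the
    # loss below is the negative mean log-likelihood over the batch.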
    def loss_fn(model, x):
        return -mx.mean(model(x))
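
    # nn.value_and_grad wraps loss_fn so that each call returns both the scalar
    # loss and the gradients with respect to the model's trainable parameters.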
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)

    with trange(args.n_steps) as steps:
        for step in steps:
            idx = np.random.permutation(x.shape[0])[: args.n_batch]
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))

            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

            steps.set_postfix(val=loss)

    # Plot samples from trained flow

    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100

    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)

        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
        axs[n_transforms].set_xlim(-2, 2)
        axs[n_transforms].set_ylim(-2, 2)
        axs[n_transforms].set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        axs[n_transforms].set_xticklabels([])
        axs[n_transforms].set_yticklabels([])

    # Plot original data
    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
    axs[-1].set_xlim(-2, 2)
    axs[-1].set_ylim(-2, 2)
    axs[-1].set_title("Original data")
    axs[-1].set_xticklabels([])
    axs[-1].set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_steps", type=int, default=5_000, help="Number of steps to train")
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument("--n_transforms", type=int, default=6, help="Number of flow transforms")
    parser.add_argument("--d_params", type=int, default=2, help="Dimensionality of modeled distribution")
    parser.add_argument("--d_hidden", type=int, default=128, help="Hidden dimensionality of coupling conditioner")
    parser.add_argument("--n_layers", type=int, default=4, help="Number of layers in coupling conditioner")
    parser.add_argument("--learning_rate", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--noise", type=float, default=0.06, help="Noise level in two moons dataset")
    parser.add_argument("--cpu", action="store_true", help="Run on the CPU instead of the default device")

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
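
# Example invocation (a sketch; assumes this script is saved as main.py and that
# flows.py providing RealNVP is importable from the same directory):
#
#   python main.py --n_steps 5000 --n_transforms 6 --learning_rate 3e-4
#
# Samples from the trained flow are written to samples.png.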