diff --git a/cifar/dataset.py b/cifar/dataset.py index 32918a4b..22b229f8 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -1,14 +1,12 @@ -import math - -import mlx.core as mx +import numpy as np from mlx.data.datasets import load_cifar10 def get_cifar10(batch_size, root=None): tr = load_cifar10(root=root) - mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) - std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) + mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) + std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) def normalize(x): x = x.astype("float32") / 255.0 @@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None): .image_random_crop("image", 32, 32) .key_transform("image", normalize) .batch(batch_size) + .prefetch(4, 4) ) test = load_cifar10(root=root, train=False) diff --git a/cifar/main.py b/cifar/main.py index f2174772..378bc424 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -1,5 +1,6 @@ import argparse import time +from functools import partial import mlx.core as mx import mlx.nn as nn @@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch): acc = mx.mean(mx.argmax(output, axis=1) == tgt) return loss, acc - train_step_fn = nn.value_and_grad(model, train_step) - losses = [] accs = [] samples_per_sec = [] + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(inp, tgt): + train_step_fn = nn.value_and_grad(model, train_step) + (loss, acc), grads = train_step_fn(model, inp, tgt) + optimizer.update(model, grads) + return loss, acc + for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["image"]) y = mx.array(batch["label"]) tic = time.perf_counter() - (loss, acc), grads = train_step_fn(model, x, y) - optimizer.update(model, grads) - mx.eval(model.parameters(), optimizer.state) + loss, acc = step(x, y) + mx.eval(state) toc = time.perf_counter() loss = loss.item() acc = acc.item() diff --git a/cifar/requirements.txt b/cifar/requirements.txt index e4764d13..03fc57c1 100644 --- a/cifar/requirements.txt +++ b/cifar/requirements.txt @@ -1,3 +1,3 @@ -mlx>=0.0.9 +mlx>=0.2 mlx-data -numpy \ No newline at end of file +numpy diff --git a/cvae/dataset.py b/cvae/dataset.py index 3af2ca32..d89d8b52 100644 --- a/cvae/dataset.py +++ b/cvae/dataset.py @@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None): .image_resize("image", h=img_size[0], w=img_size[1]) .key_transform("image", normalize) .batch(batch_size) + .prefetch(4, 4) ) # iterator over test set diff --git a/cvae/main.py b/cvae/main.py index 7c395a2d..78ac9b4a 100644 --- a/cvae/main.py +++ b/cvae/main.py @@ -2,14 +2,15 @@ import argparse import time +from functools import partial from pathlib import Path import dataset import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -import model import numpy as np +import vae from mlx.utils import tree_flatten from PIL import Image @@ -53,44 +54,6 @@ def loss_fn(model, X): return recon_loss + kl_div -def train_epoch(model, data, optimizer, epoch): - loss_acc = 0.0 - throughput_acc = 0.0 - loss_and_grad_fn = nn.value_and_grad(model, loss_fn) - - # Iterate over training batches - for batch_count, batch in enumerate(data): - X = mx.array(batch["image"]) - - throughput_tic = time.perf_counter() - - # Forward pass + backward pass + update - loss, grads = loss_and_grad_fn(model, X) - optimizer.update(model, grads) - - # Evaluate updated model parameters - mx.eval(model.parameters(), optimizer.state) - - throughput_toc = time.perf_counter() - throughput_acc += 
X.shape[0] / (throughput_toc - throughput_tic) - loss_acc += loss.item() - - if batch_count > 0 and (batch_count % 10 == 0): - print( - " | ".join( - [ - f"Epoch {epoch:4d}", - f"Loss {(loss_acc / batch_count):10.2f}", - f"Throughput {(throughput_acc / batch_count):8.2f} im/s", - f"Batch {batch_count:5d}", - ] - ), - end="\r", - ) - - return loss_acc, throughput_acc, batch_count - - def reconstruct(model, batch, out_file): # Reconstruct a single batch only images = mx.array(batch["image"]) @@ -127,10 +90,10 @@ def main(args): save_dir.mkdir(parents=True, exist_ok=True) # Load the model - vae = model.CVAE(args.latent_dims, img_size, args.max_filters) - mx.eval(vae.parameters()) + model = vae.CVAE(args.latent_dims, img_size, args.max_filters) + mx.eval(model.parameters()) - num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters())) + num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters())) print("Number of trainable params: {:0.04f} M".format(num_params / 1e6)) optimizer = optim.AdamW(learning_rate=args.lr) @@ -139,19 +102,53 @@ def main(args): train_batch = next(train_iter) test_batch = next(test_iter) + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(X): + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, X) + optimizer.update(model, grads) + return loss + for e in range(1, args.epochs + 1): # Reset iterators and stats at the beginning of each epoch train_iter.reset() - vae.train() + model.train() # Train one epoch tic = time.perf_counter() - loss_acc, throughput_acc, batch_count = train_epoch( - vae, train_iter, optimizer, e - ) - toc = time.perf_counter() + loss_acc = 0.0 + throughput_acc = 0.0 - vae.eval() + # Iterate over training batches + for batch_count, batch in enumerate(train_iter): + X = mx.array(batch["image"]) + throughput_tic = time.perf_counter() + + # Forward pass + backward pass + update + loss = step(X) + + # Evaluate updated model parameters + mx.eval(state) + + throughput_toc = time.perf_counter() + throughput_acc += X.shape[0] / (throughput_toc - throughput_tic) + loss_acc += loss.item() + + if batch_count > 0 and (batch_count % 10 == 0): + print( + " | ".join( + [ + f"Epoch {e:4d}", + f"Loss {(loss_acc / batch_count):10.2f}", + f"Throughput {(throughput_acc / batch_count):8.2f} im/s", + f"Batch {batch_count:5d}", + ] + ), + end="\r", + ) + toc = time.perf_counter() print( " | ".join( @@ -163,14 +160,17 @@ def main(args): ] ) ) + + model.eval() + # Reconstruct a batch of training and test images - reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png") - reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png") + reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png") + reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png") # Generate images - generate(vae, save_dir / f"generated_{e:03d}.png") + generate(model, save_dir / f"generated_{e:03d}.png") - vae.save_weights(str(save_dir / "weights.npz")) + model.save_weights(str(save_dir / "weights.npz")) if __name__ == "__main__": diff --git a/cvae/requirements.txt b/cvae/requirements.txt index 0fb1d31e..2a7bc1cf 100644 --- a/cvae/requirements.txt +++ b/cvae/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.0.9 +mlx>=0.2 mlx-data numpy Pillow diff --git a/cvae/model.py b/cvae/vae.py similarity index 100% rename from cvae/model.py rename to cvae/vae.py diff --git a/gcn/main.py b/gcn/main.py index 7d041b66..531e501a 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -1,4 +1,6 
@@ +import time from argparse import ArgumentParser +from functools import partial import mlx.core as mx import mlx.nn as nn @@ -47,23 +49,31 @@ def main(args): mx.eval(gcn.parameters()) optimizer = optim.Adam(learning_rate=args.lr) - loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn) + + state = [gcn.state, optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(): + loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn) + (loss, y_hat), grads = loss_and_grad_fn( + gcn, x, adj, y, train_mask, args.weight_decay + ) + optimizer.update(gcn, grads) + return loss, y_hat best_val_loss = float("inf") cnt = 0 # Training loop for epoch in range(args.epochs): - # Loss - (loss, y_hat), grads = loss_and_grad_fn( - gcn, x, adj, y, train_mask, args.weight_decay - ) - optimizer.update(gcn, grads) - mx.eval(gcn.parameters(), optimizer.state) + tic = time.time() + loss, y_hat = step() + mx.eval(state) # Validation val_loss = loss_fn(y_hat[val_mask], y[val_mask]) val_acc = eval_fn(y_hat[val_mask], y[val_mask]) + toc = time.time() # Early stopping if val_loss < best_val_loss: @@ -81,6 +91,7 @@ def main(args): f"Train loss: {loss.item():.3f}", f"Val loss: {val_loss.item():.3f}", f"Val acc: {val_acc.item():.2f}", + f"Time: {1e3*(toc - tic):.3f} (ms)", ] ) ) diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index a04cc7bb..defc3e78 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx +mlx>=0.1 numpy transformers>=4.37.0 protobuf diff --git a/mnist/main.py b/mnist/main.py index 14352df7..5ee7c5d9 100644 --- a/mnist/main.py +++ b/mnist/main.py @@ -2,6 +2,7 @@ import argparse import time +from functools import partial import mlx.core as mx import mlx.nn as nn @@ -34,10 +35,6 @@ def loss_fn(model, X, y): return nn.losses.cross_entropy(model(X), y, reduction="mean") -def eval_fn(model, X, y): - return mx.mean(mx.argmax(model(X), axis=1) == y) - - def batch_iterate(batch_size, X, y): perm = mx.array(np.random.permutation(y.size)) for s in range(0, y.size, batch_size): @@ -65,16 +62,25 @@ def main(args): model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes) mx.eval(model.parameters()) - loss_and_grad_fn = nn.value_and_grad(model, loss_fn) optimizer = optim.SGD(learning_rate=learning_rate) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + + @partial(mx.compile, inputs=model.state, outputs=model.state) + def step(X, y): + loss, grads = loss_and_grad_fn(model, X, y) + optimizer.update(model, grads) + return loss + + @partial(mx.compile, inputs=model.state) + def eval_fn(X, y): + return mx.mean(mx.argmax(model(X), axis=1) == y) for e in range(num_epochs): tic = time.perf_counter() for X, y in batch_iterate(batch_size, train_images, train_labels): - loss, grads = loss_and_grad_fn(model, X, y) - optimizer.update(model, grads) - mx.eval(model.parameters(), optimizer.state) - accuracy = eval_fn(model, test_images, test_labels) + step(X, y) + mx.eval(model.state) + accuracy = eval_fn(test_images, test_labels) toc = time.perf_counter() print( f"Epoch {e}: Test accuracy {accuracy.item():.3f}," diff --git a/mnist/requirements.txt b/mnist/requirements.txt index 15e44a27..d269b84d 100644 --- a/mnist/requirements.txt +++ b/mnist/requirements.txt @@ -1,2 +1,2 @@ -mlx -numpy \ No newline at end of file +mlx>=0.2 +numpy diff --git a/normalizing_flow/main.py b/normalizing_flow/main.py index 2956a098..27a2e003 100644 --- a/normalizing_flow/main.py +++ b/normalizing_flow/main.py @@ -1,5 +1,7 @@ # 
Copyright © 2023-2024 Apple Inc. +from functools import partial + import matplotlib.pyplot as plt import mlx.core as mx import mlx.nn as nn @@ -27,18 +29,23 @@ def main(args): def loss_fn(model, x): return -mx.mean(model(x)) - loss_and_grad_fn = nn.value_and_grad(model, loss_fn) optimizer = optim.Adam(learning_rate=args.learning_rate) + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x): + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, x) + optimizer.update(model, grads) + return loss + with trange(args.n_steps) as steps: - for step in steps: + for it in steps: idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch) - loss, grads = loss_and_grad_fn(model, mx.array(x[idx])) - - optimizer.update(model, grads) - mx.eval(model.parameters()) - - steps.set_postfix(val=loss) + loss = step(mx.array(x[idx])) + mx.eval(state) + steps.set_postfix(val=loss.item()) # Plot samples from trained flow diff --git a/normalizing_flow/requirements.txt b/normalizing_flow/requirements.txt index 5b335764..4abdf74c 100644 --- a/normalizing_flow/requirements.txt +++ b/normalizing_flow/requirements.txt @@ -1,5 +1,5 @@ -mlx +mlx>=0.2 numpy tqdm scikit-learn -matplotlib \ No newline at end of file +matplotlib diff --git a/speechcommands/main.py b/speechcommands/main.py index 3f81e40b..bc318dad 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -1,5 +1,6 @@ import argparse import time +from functools import partial import kwt import mlx.core as mx @@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None): .key_transform("audio", normalize) .shuffle() .batch(batch_size) + .to_stream() + .prefetch(4, 4) ) return data_iter -def eval_fn(model, inp, tgt): - return mx.mean(mx.argmax(model(inp), axis=1) == tgt) +def eval_fn(model, x, y): + return mx.mean(mx.argmax(model(x), axis=1) == y) def train_epoch(model, train_iter, optimizer, epoch): - def train_step(model, inp, tgt): - output = model(inp) - loss = mx.mean(nn.losses.cross_entropy(output, tgt)) - acc = mx.mean(mx.argmax(output, axis=1) == tgt) + def train_step(model, x, y): + output = model(x) + loss = mx.mean(nn.losses.cross_entropy(output, y)) + acc = mx.mean(mx.argmax(output, axis=1) == y) return loss, acc - train_step_fn = nn.value_and_grad(model, train_step) + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def step(x, y): + (loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y) + optimizer.update(model, grads) + return loss, acc losses = [] accs = [] @@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) tic = time.perf_counter() - (loss, acc), grads = train_step_fn(model, x, y) - optimizer.update(model, grads) - mx.eval(model.parameters(), optimizer.state) + loss, acc = step(x, y) + mx.eval(state) toc = time.perf_counter() loss = loss.item() acc = acc.item() diff --git a/speechcommands/requirements.txt b/speechcommands/requirements.txt index 4e6a06dd..35839609 100644 --- a/speechcommands/requirements.txt +++ b/speechcommands/requirements.txt @@ -1,2 +1,2 @@ -mlx +mlx>=0.2 mlx-data diff --git a/transformer_lm/main.py b/transformer_lm/main.py index 6d81a63a..214518d1 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -2,6 +2,7 @@ import math import time +from functools import partial import datasets import mlx.core as mx @@ -37,11 +38,6 @@ class TransformerLM(nn.Module): 
         x = self.transformer(x, mask)
         return self.out_proj(x)
 
-    def loss(self, x, y, reduce=True):
-        logits = self(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
-
 
 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
+    def loss_fn(model, x, y, reduce=True):
+        logits = model(x)
+        losses = nn.losses.cross_entropy(logits, y)
+        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
-    loss_and_grad_fn = nn.value_and_grad(model, model.loss)
 
-    def eval_fn(model, dataset):
+    def eval_fn(dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = model.loss(bx, by, reduce=False)
+            losses = loss_fn(model, bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
 
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        optimizer.update(model, grads)
+        return loss
+
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
    tic = time.perf_counter()
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
-        loss, grads = loss_and_grad_fn(inputs, targets)
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        optimizer.update(model, grads)
-        del grads
-        mx.eval(loss, model.parameters())
+        loss = step(inputs, targets)
+        mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
             train_loss = np.mean(losses)
diff --git a/transformer_lm/requirements.txt b/transformer_lm/requirements.txt
index e5d7a2b1..c6739108 100644
--- a/transformer_lm/requirements.txt
+++ b/transformer_lm/requirements.txt
@@ -1 +1 @@
-mlx >= 0.12
+mlx>=0.2
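
Note: every training loop touched in this patch follows the same pattern: move the forward/backward/update into a function wrapped with mx.compile, pass the mutable model and optimizer state as the compiled function's inputs and outputs, and call mx.eval(state) once per iteration. Below is a minimal, self-contained sketch of that pattern. The tiny linear model, SGD optimizer, and random batch are placeholders for illustration only and are not taken from any file in this diff.

    from functools import partial

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    # Placeholder model and optimizer; any nn.Module / optimizer pair is wired the same way.
    model = nn.Linear(8, 2)
    mx.eval(model.parameters())
    optimizer = optim.SGD(learning_rate=0.1)

    def loss_fn(model, x, y):
        return nn.losses.cross_entropy(model(x), y, reduction="mean")

    # The mutable state the compiled graph must track across calls:
    # model parameters plus optimizer buffers.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        # Forward + backward + parameter update captured in one compiled graph.
        loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
        optimizer.update(model, grads)
        return loss

    # Placeholder batch of random data.
    x = mx.random.normal((16, 8))
    y = mx.random.randint(0, 2, (16,))
    loss = step(x, y)
    mx.eval(state)  # materialize the updated parameters and optimizer state
    print(loss.item())

Passing the state as both compile inputs and outputs is what lets the compiled function see the in-place updates performed by optimizer.update between calls; compiled steps that involve randomness also include mx.random.state in that list, as gcn/main.py does above.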