Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
Awni Hannun 2024-02-08 13:00:41 -08:00 committed by GitHub
parent da7adae5ec
commit f45a1ab83c
17 changed files with 164 additions and 118 deletions
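
The change applied across the files below follows one pattern: each example's forward/backward/update step is wrapped in mx.compile, with the model and optimizer state declared as implicit inputs and outputs so that updates made inside the compiled function are visible outside it, and a single mx.eval(state) per batch forces the lazy computation. A minimal, self-contained sketch of that pattern (the model, loss function, and data here are illustrative stand-ins, not taken from any file in this diff):

# Sketch of the compiled training step used throughout this commit.
# The Linear model, MSE loss, and random data are placeholders.
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(10, 1)
optimizer = optim.SGD(learning_rate=1e-2)
mx.eval(model.parameters())

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

# Capture everything the step mutates: parameters and optimizer state.
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

x, y = mx.random.normal((4, 10)), mx.random.normal((4, 1))
loss = step(x, y)
mx.eval(state)  # evaluate once per batch; the compiled graph runs here
print(loss.item())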

View File

@ -1,14 +1,12 @@
import math
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
def get_cifar10(batch_size, root=None):
tr = load_cifar10(root=root)
mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
def normalize(x):
x = x.astype("float32") / 255.0
@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
.image_random_crop("image", 32, 32)
.key_transform("image", normalize)
.batch(batch_size)
.prefetch(4, 4)
)
test = load_cifar10(root=root, train=False)
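
A note on the prefetch change: the .prefetch(4, 4) calls added to these pipelines ask mlx.data to prepare batches ahead of the training loop; as I read the mlx.data stream API, the two arguments are the number of samples to buffer (here, batches, since it comes after .batch) and the number of background threads doing the work. Prefetch is a stream operation, which is why the speech commands pipeline later in this diff also adds a .to_stream() before it. A rough standalone sketch under those assumptions (the dataset choice and transform are illustrative):

# Illustrative mlx.data pipeline ending in prefetch; assumes mlx-data is
# installed and mirrors the structure of the pipelines in this commit.
from mlx.data.datasets import load_mnist

def mnist_batches(batch_size, root=None):
    data = load_mnist(root=root, train=True)
    return (
        data.shuffle()
        .to_stream()
        .key_transform("image", lambda x: x.astype("float32") / 255.0)
        .batch(batch_size)
        .prefetch(4, 4)  # keep up to 4 batches ready using 4 worker threads
    )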

View File

@ -1,5 +1,6 @@
import argparse
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
train_step_fn = nn.value_and_grad(model, train_step)
losses = []
accs = []
samples_per_sec = []
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
optimizer.update(model, grads)
return loss, acc
for batch_counter, batch in enumerate(train_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
loss, acc = step(x, y)
mx.eval(state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()

View File

@ -1,3 +1,3 @@
mlx>=0.0.9
mlx>=0.2
mlx-data
numpy
numpy

View File

@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
.image_resize("image", h=img_size[0], w=img_size[1])
.key_transform("image", normalize)
.batch(batch_size)
.prefetch(4, 4)
)
# iterator over test set

View File

@ -2,14 +2,15 @@
import argparse
import time
from functools import partial
from pathlib import Path
import dataset
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import model
import numpy as np
import vae
from mlx.utils import tree_flatten
from PIL import Image
@ -53,44 +54,6 @@ def loss_fn(model, X):
return recon_loss + kl_div
def train_epoch(model, data, optimizer, epoch):
loss_acc = 0.0
throughput_acc = 0.0
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Iterate over training batches
for batch_count, batch in enumerate(data):
X = mx.array(batch["image"])
throughput_tic = time.perf_counter()
# Forward pass + backward pass + update
loss, grads = loss_and_grad_fn(model, X)
optimizer.update(model, grads)
# Evaluate updated model parameters
mx.eval(model.parameters(), optimizer.state)
throughput_toc = time.perf_counter()
throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
loss_acc += loss.item()
if batch_count > 0 and (batch_count % 10 == 0):
print(
" | ".join(
[
f"Epoch {epoch:4d}",
f"Loss {(loss_acc / batch_count):10.2f}",
f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
f"Batch {batch_count:5d}",
]
),
end="\r",
)
return loss_acc, throughput_acc, batch_count
def reconstruct(model, batch, out_file):
# Reconstruct a single batch only
images = mx.array(batch["image"])
@ -127,10 +90,10 @@ def main(args):
save_dir.mkdir(parents=True, exist_ok=True)
# Load the model
vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
mx.eval(vae.parameters())
model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
mx.eval(model.parameters())
num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))
optimizer = optim.AdamW(learning_rate=args.lr)
@ -139,19 +102,53 @@ def main(args):
train_batch = next(train_iter)
test_batch = next(test_iter)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(X):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, X)
optimizer.update(model, grads)
return loss
for e in range(1, args.epochs + 1):
# Reset iterators and stats at the beginning of each epoch
train_iter.reset()
vae.train()
model.train()
# Train one epoch
tic = time.perf_counter()
loss_acc, throughput_acc, batch_count = train_epoch(
vae, train_iter, optimizer, e
)
toc = time.perf_counter()
loss_acc = 0.0
throughput_acc = 0.0
vae.eval()
# Iterate over training batches
for batch_count, batch in enumerate(train_iter):
X = mx.array(batch["image"])
throughput_tic = time.perf_counter()
# Forward pass + backward pass + update
loss = step(X)
# Evaluate updated model parameters
mx.eval(state)
throughput_toc = time.perf_counter()
throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
loss_acc += loss.item()
if batch_count > 0 and (batch_count % 10 == 0):
print(
" | ".join(
[
f"Epoch {e:4d}",
f"Loss {(loss_acc / batch_count):10.2f}",
f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
f"Batch {batch_count:5d}",
]
),
end="\r",
)
toc = time.perf_counter()
print(
" | ".join(
@ -163,14 +160,17 @@ def main(args):
]
)
)
model.eval()
# Reconstruct a batch of training and test images
reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")
# Generate images
generate(vae, save_dir / f"generated_{e:03d}.png")
generate(model, save_dir / f"generated_{e:03d}.png")
vae.save_weights(str(save_dir / "weights.npz"))
model.save_weights(str(save_dir / "weights.npz"))
if __name__ == "__main__":

View File

@ -1,4 +1,4 @@
mlx>=0.0.9
mlx>=0.2
mlx-data
numpy
Pillow

View File

@ -1,4 +1,6 @@
import time
from argparse import ArgumentParser
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@ -47,23 +49,31 @@ def main(args):
mx.eval(gcn.parameters())
optimizer = optim.Adam(learning_rate=args.lr)
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
state = [gcn.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def step():
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
)
optimizer.update(gcn, grads)
return loss, y_hat
best_val_loss = float("inf")
cnt = 0
# Training loop
for epoch in range(args.epochs):
# Loss
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
)
optimizer.update(gcn, grads)
mx.eval(gcn.parameters(), optimizer.state)
tic = time.time()
loss, y_hat = step()
mx.eval(state)
# Validation
val_loss = loss_fn(y_hat[val_mask], y[val_mask])
val_acc = eval_fn(y_hat[val_mask], y[val_mask])
toc = time.time()
# Early stopping
if val_loss < best_val_loss:
@ -81,6 +91,7 @@ def main(args):
f"Train loss: {loss.item():.3f}",
f"Val loss: {val_loss.item():.3f}",
f"Val acc: {val_acc.item():.2f}",
f"Time: {1e3*(toc - tic):.3f} (ms)",
]
)
)
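
One extra wrinkle in the GCN example: the model uses dropout, so the compiled state above also includes mx.random.state. A compiled function only tracks state that is declared as an input/output; without the random state listed, the PRNG key would effectively be frozen at trace time, so the dropout mask would not change between calls and the global random state would not advance. A minimal sketch of the same idea with a throwaway dropout layer (the layer and data are illustrative):

# Illustrative: a compiled function that draws random numbers needs the
# global RNG state in its compiled inputs/outputs, as in the GCN change.
from functools import partial

import mlx.core as mx
import mlx.nn as nn

dropout = nn.Dropout(p=0.5)

state = [mx.random.state]  # nothing else is mutated here, only the RNG

@partial(mx.compile, inputs=state, outputs=state)
def noisy(x):
    return dropout(x)

x = mx.ones((4,))
print(noisy(x))  # different masks on each call because the RNG state advances
print(noisy(x))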

View File

@ -1,4 +1,4 @@
mlx
mlx>=0.1
numpy
transformers>=4.37.0
protobuf

View File

@ -2,6 +2,7 @@
import argparse
import time
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@ -34,10 +35,6 @@ def loss_fn(model, X, y):
return nn.losses.cross_entropy(model(X), y, reduction="mean")
def eval_fn(model, X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
def batch_iterate(batch_size, X, y):
perm = mx.array(np.random.permutation(y.size))
for s in range(0, y.size, batch_size):
@ -65,16 +62,25 @@ def main(args):
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=learning_rate)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@partial(mx.compile, inputs=model.state, outputs=model.state)
def step(X, y):
loss, grads = loss_and_grad_fn(model, X, y)
optimizer.update(model, grads)
return loss
@partial(mx.compile, inputs=model.state)
def eval_fn(X, y):
return mx.mean(mx.argmax(model(X), axis=1) == y)
for e in range(num_epochs):
tic = time.perf_counter()
for X, y in batch_iterate(batch_size, train_images, train_labels):
loss, grads = loss_and_grad_fn(model, X, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
accuracy = eval_fn(model, test_images, test_labels)
step(X, y)
mx.eval(model.state)
accuracy = eval_fn(test_images, test_labels)
toc = time.perf_counter()
print(
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"

View File

@ -1,2 +1,2 @@
mlx
numpy
mlx>=0.2
numpy

View File

@ -1,5 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
@ -27,18 +29,23 @@ def main(args):
def loss_fn(model, x):
return -mx.mean(model(x))
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.Adam(learning_rate=args.learning_rate)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x)
optimizer.update(model, grads)
return loss
with trange(args.n_steps) as steps:
for step in steps:
for it in steps:
idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
optimizer.update(model, grads)
mx.eval(model.parameters())
steps.set_postfix(val=loss)
loss = step(mx.array(x[idx]))
mx.eval(state)
steps.set_postfix(val=loss.item())
# Plot samples from trained flow

View File

@ -1,5 +1,5 @@
mlx
mlx>=0.2
numpy
tqdm
scikit-learn
matplotlib
matplotlib

View File

@ -1,5 +1,6 @@
import argparse
import time
from functools import partial
import kwt
import mlx.core as mx
@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None):
.key_transform("audio", normalize)
.shuffle()
.batch(batch_size)
.to_stream()
.prefetch(4, 4)
)
return data_iter
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
def eval_fn(model, x, y):
return mx.mean(mx.argmax(model(x), axis=1) == y)
def train_epoch(model, train_iter, optimizer, epoch):
def train_step(model, inp, tgt):
output = model(inp)
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
def train_step(model, x, y):
output = model(x)
loss = mx.mean(nn.losses.cross_entropy(output, y))
acc = mx.mean(mx.argmax(output, axis=1) == y)
return loss, acc
train_step_fn = nn.value_and_grad(model, train_step)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
(loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y)
optimizer.update(model, grads)
return loss, acc
losses = []
accs = []
@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
loss, acc = step(x, y)
mx.eval(state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()

View File

@ -1,2 +1,2 @@
mlx
mlx>=0.2
mlx-data

View File

@ -2,6 +2,7 @@
import math
import time
from functools import partial
import datasets
import mlx.core as mx
@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
def to_samples(context_size, dataset):
tokens = dataset.size
@ -88,31 +84,42 @@ def main(args):
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
def loss_fn(model, x, y, reduce=True):
logits = model(x)
losses = nn.losses.cross_entropy(logits, y)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
optimizer = optim.AdamW(
learning_rate=args.learning_rate, weight_decay=args.weight_decay
)
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
def eval_fn(model, dataset):
def eval_fn(dataset):
inputs, targets = map(mx.array, to_samples(context_size, dataset))
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(mx.array, (bx, by))
losses = model.loss(bx, by, reduce=False)
losses = loss_fn(model, bx, by, reduce=False)
loss += mx.sum(losses).item()
return loss / len(targets)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(inputs, targets):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, inputs, targets)
optimizer.update(model, grads)
return loss
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
optimizer.update(model, grads)
del grads
mx.eval(loss, model.parameters())
loss = step(inputs, targets)
mx.eval(state)
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)

View File

@ -1 +1 @@
mlx >= 0.12
mlx>=0.2