Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
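
All of the updated training loops share the same pattern: capture the model and optimizer state, wrap the gradient-and-update step in `mx.compile`, and force evaluation of that state once per iteration. Below is a minimal, self-contained sketch of that pattern; the toy `nn.Linear` model, `SGD` optimizer, and random data are illustrative assumptions, not part of this commit.

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy model, optimizer, and data purely for illustration.
model = nn.Linear(4, 2)
mx.eval(model.parameters())
optimizer = optim.SGD(learning_rate=0.1)


def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")


# The compiled step reads and writes the model parameters and the
# optimizer state, so both are passed as implicit inputs and outputs.
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss


x = mx.random.normal((8, 4))
y = mx.array([0, 1, 0, 1, 0, 1, 0, 1])
for _ in range(3):
    loss = step(x, y)
    mx.eval(state)  # materialize the lazily computed update each iteration
```

Passing `state` as both `inputs` and `outputs` lets the compiled function observe and mutate arrays (parameters, optimizer buffers) that are not explicit function arguments.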
Awni Hannun 2024-02-08 13:00:41 -08:00 committed by GitHub
parent da7adae5ec
commit f45a1ab83c
17 changed files with 164 additions and 118 deletions

View File

@@ -1,14 +1,12 @@
-import math
-import mlx.core as mx
+import numpy as np
 from mlx.data.datasets import load_cifar10


 def get_cifar10(batch_size, root=None):
     tr = load_cifar10(root=root)

-    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
-    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
+    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
+    std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

     def normalize(x):
         x = x.astype("float32") / 255.0
@@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
         .image_random_crop("image", 32, 32)
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     test = load_cifar10(root=root, train=False)

View File

@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)

     losses = []
     accs = []
     samples_per_sec = []

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inp, tgt):
+        train_step_fn = nn.value_and_grad(model, train_step)
+        (loss, acc), grads = train_step_fn(model, inp, tgt)
+        optimizer.update(model, grads)
+        return loss, acc
+
     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()

View File

@@ -1,3 +1,3 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy

View File

@@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     # iterator over test set

View File

@@ -2,14 +2,15 @@
 import argparse
 import time
+from functools import partial
 from pathlib import Path

 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image
@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div


-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)

     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())

-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

     optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,19 +102,53 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()

         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
-        toc = time.perf_counter()

-        vae.eval()
+        loss_acc = 0.0
+        throughput_acc = 0.0
+
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
+
+        toc = time.perf_counter()

         print(
             " | ".join(
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )
+        model.eval()

         # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")

-    vae.save_weights(str(save_dir / "weights.npz"))
+    model.save_weights(str(save_dir / "weights.npz"))


 if __name__ == "__main__":

View File

@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow

View File

@@ -1,4 +1,6 @@
+import time
 from argparse import ArgumentParser
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -47,23 +49,31 @@ def main(args):
     mx.eval(gcn.parameters())

     optimizer = optim.Adam(learning_rate=args.lr)
-    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

+    state = [gcn.state, optimizer.state, mx.random.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step():
+        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+        (loss, y_hat), grads = loss_and_grad_fn(
+            gcn, x, adj, y, train_mask, args.weight_decay
+        )
+        optimizer.update(gcn, grads)
+        return loss, y_hat
+
     best_val_loss = float("inf")
     cnt = 0

     # Training loop
     for epoch in range(args.epochs):
-        # Loss
-        (loss, y_hat), grads = loss_and_grad_fn(
-            gcn, x, adj, y, train_mask, args.weight_decay
-        )
-        optimizer.update(gcn, grads)
-        mx.eval(gcn.parameters(), optimizer.state)
+        tic = time.time()
+        loss, y_hat = step()
+        mx.eval(state)

         # Validation
         val_loss = loss_fn(y_hat[val_mask], y[val_mask])
         val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+        toc = time.time()

         # Early stopping
         if val_loss < best_val_loss:
@@ -81,6 +91,7 @@ def main(args):
                     f"Train loss: {loss.item():.3f}",
                     f"Val loss: {val_loss.item():.3f}",
                     f"Val acc: {val_acc.item():.2f}",
+                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                 ]
             )
         )
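
Because the GCN forward pass uses dropout, its compiled step above also captures `mx.random.state`; if the PRNG state were left out, the random key would be baked into the compiled graph and every step would reuse the same dropout mask. A small illustrative sketch of that state list follows; the toy `nn.Sequential` model, optimizer, and data are assumptions, not the example's code.

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Toy model with dropout; the layer sizes are illustrative only.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 2))
mx.eval(model.parameters())
optimizer = optim.Adam(learning_rate=1e-2)


def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")


# Dropout draws random numbers inside the compiled graph, so the global
# PRNG state is captured alongside the parameters and optimizer state.
state = [model.state, optimizer.state, mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss


x = mx.random.normal((8, 8))
y = mx.array([0, 1] * 4)
loss = step(x, y)
mx.eval(state)
```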

View File

@@ -1,4 +1,4 @@
-mlx
+mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf

View File

@@ -2,6 +2,7 @@
 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -34,10 +35,6 @@ def loss_fn(model, X, y):
     return nn.losses.cross_entropy(model(X), y, reduction="mean")


-def eval_fn(model, X, y):
-    return mx.mean(mx.argmax(model(X), axis=1) == y)
-
-
 def batch_iterate(batch_size, X, y):
     perm = mx.array(np.random.permutation(y.size))
     for s in range(0, y.size, batch_size):
@@ -65,16 +62,25 @@ def main(args):
     model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
     mx.eval(model.parameters())

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.SGD(learning_rate=learning_rate)
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+    @partial(mx.compile, inputs=model.state, outputs=model.state)
+    def step(X, y):
+        loss, grads = loss_and_grad_fn(model, X, y)
+        optimizer.update(model, grads)
+        return loss
+
+    @partial(mx.compile, inputs=model.state)
+    def eval_fn(X, y):
+        return mx.mean(mx.argmax(model(X), axis=1) == y)

     for e in range(num_epochs):
         tic = time.perf_counter()
         for X, y in batch_iterate(batch_size, train_images, train_labels):
-            loss, grads = loss_and_grad_fn(model, X, y)
-            optimizer.update(model, grads)
-            mx.eval(model.parameters(), optimizer.state)
-        accuracy = eval_fn(model, test_images, test_labels)
+            step(X, y)
+            mx.eval(model.state)
+        accuracy = eval_fn(test_images, test_labels)
         toc = time.perf_counter()
         print(
             f"Epoch {e}: Test accuracy {accuracy.item():.3f},"

View File

@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 numpy

View File

@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import matplotlib.pyplot as plt
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,18 +29,23 @@ def main(args):
     def loss_fn(model, x):
         return -mx.mean(model(x))

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.Adam(learning_rate=args.learning_rate)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, x)
+        optimizer.update(model, grads)
+        return loss
+
     with trange(args.n_steps) as steps:
-        for step in steps:
+        for it in steps:
             idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
-            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
-            optimizer.update(model, grads)
-            mx.eval(model.parameters())
-            steps.set_postfix(val=loss)
+            loss = step(mx.array(x[idx]))
+            mx.eval(state)
+            steps.set_postfix(val=loss.item())

     # Plot samples from trained flow

View File

@@ -1,4 +1,4 @@
-mlx
+mlx>=0.2
 numpy
 tqdm
 scikit-learn

View File

@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import kwt
 import mlx.core as mx
@@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None):
         .key_transform("audio", normalize)
         .shuffle()
         .batch(batch_size)
+        .to_stream()
+        .prefetch(4, 4)
     )

     return data_iter


-def eval_fn(model, inp, tgt):
-    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
+def eval_fn(model, x, y):
+    return mx.mean(mx.argmax(model(x), axis=1) == y)


 def train_epoch(model, train_iter, optimizer, epoch):
-    def train_step(model, inp, tgt):
-        output = model(inp)
-        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
-        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
+    def train_step(model, x, y):
+        output = model(x)
+        loss = mx.mean(nn.losses.cross_entropy(output, y))
+        acc = mx.mean(mx.argmax(output, axis=1) == y)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x, y):
+        (loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y)
+        optimizer.update(model, grads)
+        return loss, acc

     losses = []
     accs = []
@@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch):
         x = mx.array(batch["audio"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()

View File

@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 mlx-data

View File

@@ -2,6 +2,7 @@
 import math
 import time
+from functools import partial

 import datasets
 import mlx.core as mx
@@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
         x = self.transformer(x, mask)
         return self.out_proj(x)

-    def loss(self, x, y, reduce=True):
-        logits = self(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
-

 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

+    def loss_fn(model, x, y, reduce=True):
+        logits = model(x)
+        losses = nn.losses.cross_entropy(logits, y)
+        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
-    loss_and_grad_fn = nn.value_and_grad(model, model.loss)

-    def eval_fn(model, dataset):
+    def eval_fn(dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = model.loss(bx, by, reduce=False)
+            losses = loss_fn(model, bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        optimizer.update(model, grads)
+        return loss
+
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
-        loss, grads = loss_and_grad_fn(inputs, targets)
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        optimizer.update(model, grads)
-        del grads
-        mx.eval(loss, model.parameters())
+        loss = step(inputs, targets)
+        mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
             train_loss = np.mean(losses)

View File

@@ -1 +1 @@
-mlx >= 0.12
+mlx>=0.2