Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Update a few examples to use compile (#420)
* update a few examples to use compile
* update mnist
* add compile to vae and rename some stuff for simplicity
* update reqs
* use state in eval
* GCN example with RNG + dropout
* add a bit of prefetching
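All of these examples move to the same pattern: wrap the per-iteration work in a function compiled with mx.compile, list the model parameters, the optimizer state, and (where dropout or other random ops are involved) mx.random.state as the compiled function's implicit inputs and outputs, then call mx.eval(state) once per iteration instead of evaluating parameters and optimizer state separately. Below is a minimal sketch of that pattern, not code from this diff; model, optimizer, loss_fn, and data_iter are hypothetical stand-ins for whatever each example defines.

from functools import partial

import mlx.core as mx
import mlx.nn as nn


def train(model, optimizer, loss_fn, data_iter):
    # Everything the compiled step reads or mutates implicitly: parameters,
    # optimizer state, and the global RNG state (needed when dropout is used).
    state = [model.state, optimizer.state, mx.random.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for x, y in data_iter:
        loss = step(x, y)
        mx.eval(state)  # force evaluation of the lazily updated state

Evaluation-only functions use the same idea but only need inputs=model.state (and no outputs), as in the mnist example's compiled eval_fn below.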
This commit is contained in:
parent da7adae5ec
commit f45a1ab83c
@@ -1,14 +1,12 @@
-import math
-import mlx.core as mx
 import numpy as np
 from mlx.data.datasets import load_cifar10


 def get_cifar10(batch_size, root=None):
     tr = load_cifar10(root=root)

-    mean = mx.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
-    std = mx.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
+    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
+    std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

     def normalize(x):
         x = x.astype("float32") / 255.0
@@ -23,6 +21,7 @@ def get_cifar10(batch_size, root=None):
         .image_random_crop("image", 32, 32)
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     test = load_cifar10(root=root, train=False)
@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -33,19 +34,25 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
-
     losses = []
     accs = []
     samples_per_sec = []

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inp, tgt):
+        train_step_fn = nn.value_and_grad(model, train_step)
+        (loss, acc), grads = train_step_fn(model, inp, tgt)
+        optimizer.update(model, grads)
+        return loss, acc
+
     for batch_counter, batch in enumerate(train_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()
@@ -1,3 +1,3 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
@@ -23,6 +23,7 @@ def mnist(batch_size, img_size, root=None):
         .image_resize("image", h=img_size[0], w=img_size[1])
         .key_transform("image", normalize)
         .batch(batch_size)
+        .prefetch(4, 4)
     )

     # iterator over test set
cvae/main.py
@@ -2,14 +2,15 @@

 import argparse
 import time
+from functools import partial
 from pathlib import Path

 import dataset
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
-import model
 import numpy as np
+import vae
 from mlx.utils import tree_flatten
 from PIL import Image

@@ -53,44 +54,6 @@ def loss_fn(model, X):
     return recon_loss + kl_div


-def train_epoch(model, data, optimizer, epoch):
-    loss_acc = 0.0
-    throughput_acc = 0.0
-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
-
-    # Iterate over training batches
-    for batch_count, batch in enumerate(data):
-        X = mx.array(batch["image"])
-
-        throughput_tic = time.perf_counter()
-
-        # Forward pass + backward pass + update
-        loss, grads = loss_and_grad_fn(model, X)
-        optimizer.update(model, grads)
-
-        # Evaluate updated model parameters
-        mx.eval(model.parameters(), optimizer.state)
-
-        throughput_toc = time.perf_counter()
-        throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
-        loss_acc += loss.item()
-
-        if batch_count > 0 and (batch_count % 10 == 0):
-            print(
-                " | ".join(
-                    [
-                        f"Epoch {epoch:4d}",
-                        f"Loss {(loss_acc / batch_count):10.2f}",
-                        f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
-                        f"Batch {batch_count:5d}",
-                    ]
-                ),
-                end="\r",
-            )
-
-    return loss_acc, throughput_acc, batch_count
-
-
 def reconstruct(model, batch, out_file):
     # Reconstruct a single batch only
     images = mx.array(batch["image"])
@@ -127,10 +90,10 @@ def main(args):
     save_dir.mkdir(parents=True, exist_ok=True)

     # Load the model
-    vae = model.CVAE(args.latent_dims, img_size, args.max_filters)
-    mx.eval(vae.parameters())
+    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
+    mx.eval(model.parameters())

-    num_params = sum(x.size for _, x in tree_flatten(vae.trainable_parameters()))
+    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
     print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

     optimizer = optim.AdamW(learning_rate=args.lr)
@@ -139,19 +102,53 @@ def main(args):
     train_batch = next(train_iter)
     test_batch = next(test_iter)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(X):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, X)
+        optimizer.update(model, grads)
+        return loss
+
     for e in range(1, args.epochs + 1):
         # Reset iterators and stats at the beginning of each epoch
         train_iter.reset()
-        vae.train()
+        model.train()

         # Train one epoch
         tic = time.perf_counter()
-        loss_acc, throughput_acc, batch_count = train_epoch(
-            vae, train_iter, optimizer, e
-        )
-        toc = time.perf_counter()
+        loss_acc = 0.0
+        throughput_acc = 0.0

-        vae.eval()
+        # Iterate over training batches
+        for batch_count, batch in enumerate(train_iter):
+            X = mx.array(batch["image"])
+            throughput_tic = time.perf_counter()
+
+            # Forward pass + backward pass + update
+            loss = step(X)
+
+            # Evaluate updated model parameters
+            mx.eval(state)
+
+            throughput_toc = time.perf_counter()
+            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
+            loss_acc += loss.item()
+
+            if batch_count > 0 and (batch_count % 10 == 0):
+                print(
+                    " | ".join(
+                        [
+                            f"Epoch {e:4d}",
+                            f"Loss {(loss_acc / batch_count):10.2f}",
+                            f"Throughput {(throughput_acc / batch_count):8.2f} im/s",
+                            f"Batch {batch_count:5d}",
+                        ]
+                    ),
+                    end="\r",
+                )
+        toc = time.perf_counter()

         print(
             " | ".join(
@@ -163,14 +160,17 @@ def main(args):
                 ]
             )
         )

+        model.eval()
+
         # Reconstruct a batch of training and test images
-        reconstruct(vae, train_batch, save_dir / f"train_{e:03d}.png")
-        reconstruct(vae, test_batch, save_dir / f"test_{e:03d}.png")
+        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
+        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

         # Generate images
-        generate(vae, save_dir / f"generated_{e:03d}.png")
+        generate(model, save_dir / f"generated_{e:03d}.png")

-    vae.save_weights(str(save_dir / "weights.npz"))
+    model.save_weights(str(save_dir / "weights.npz"))


 if __name__ == "__main__":
@@ -1,4 +1,4 @@
-mlx>=0.0.9
+mlx>=0.2
 mlx-data
 numpy
 Pillow
gcn/main.py
@@ -1,4 +1,6 @@
+import time
 from argparse import ArgumentParser
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -47,23 +49,31 @@ def main(args):
     mx.eval(gcn.parameters())

     optimizer = optim.Adam(learning_rate=args.lr)
-    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+
+    state = [gcn.state, optimizer.state, mx.random.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step():
+        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+        (loss, y_hat), grads = loss_and_grad_fn(
+            gcn, x, adj, y, train_mask, args.weight_decay
+        )
+        optimizer.update(gcn, grads)
+        return loss, y_hat

     best_val_loss = float("inf")
     cnt = 0

     # Training loop
     for epoch in range(args.epochs):
-        # Loss
-        (loss, y_hat), grads = loss_and_grad_fn(
-            gcn, x, adj, y, train_mask, args.weight_decay
-        )
-        optimizer.update(gcn, grads)
-        mx.eval(gcn.parameters(), optimizer.state)
+        tic = time.time()
+        loss, y_hat = step()
+        mx.eval(state)

         # Validation
         val_loss = loss_fn(y_hat[val_mask], y[val_mask])
         val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+        toc = time.time()

         # Early stopping
         if val_loss < best_val_loss:
@@ -81,6 +91,7 @@ def main(args):
                     f"Train loss: {loss.item():.3f}",
                     f"Val loss: {val_loss.item():.3f}",
                     f"Val acc: {val_acc.item():.2f}",
+                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                 ]
             )
         )
@@ -1,4 +1,4 @@
-mlx
+mlx>=0.1
 numpy
 transformers>=4.37.0
 protobuf
@@ -2,6 +2,7 @@

 import argparse
 import time
+from functools import partial

 import mlx.core as mx
 import mlx.nn as nn
@@ -34,10 +35,6 @@ def loss_fn(model, X, y):
     return nn.losses.cross_entropy(model(X), y, reduction="mean")


-def eval_fn(model, X, y):
-    return mx.mean(mx.argmax(model(X), axis=1) == y)
-
-
 def batch_iterate(batch_size, X, y):
     perm = mx.array(np.random.permutation(y.size))
     for s in range(0, y.size, batch_size):
@@ -65,16 +62,25 @@ def main(args):
     model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
     mx.eval(model.parameters())

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.SGD(learning_rate=learning_rate)
+    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+
+    @partial(mx.compile, inputs=model.state, outputs=model.state)
+    def step(X, y):
+        loss, grads = loss_and_grad_fn(model, X, y)
+        optimizer.update(model, grads)
+        return loss
+
+    @partial(mx.compile, inputs=model.state)
+    def eval_fn(X, y):
+        return mx.mean(mx.argmax(model(X), axis=1) == y)

     for e in range(num_epochs):
         tic = time.perf_counter()
         for X, y in batch_iterate(batch_size, train_images, train_labels):
-            loss, grads = loss_and_grad_fn(model, X, y)
-            optimizer.update(model, grads)
-            mx.eval(model.parameters(), optimizer.state)
-        accuracy = eval_fn(model, test_images, test_labels)
+            step(X, y)
+            mx.eval(model.state)
+        accuracy = eval_fn(test_images, test_labels)
         toc = time.perf_counter()
         print(
             f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 numpy
@@ -1,5 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from functools import partial
+
 import matplotlib.pyplot as plt
 import mlx.core as mx
 import mlx.nn as nn
@@ -27,18 +29,23 @@ def main(args):
     def loss_fn(model, x):
         return -mx.mean(model(x))

-    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
     optimizer = optim.Adam(learning_rate=args.learning_rate)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, x)
+        optimizer.update(model, grads)
+        return loss
+
     with trange(args.n_steps) as steps:
-        for step in steps:
+        for it in steps:
             idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
-            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
-
-            optimizer.update(model, grads)
-            mx.eval(model.parameters())
-
-            steps.set_postfix(val=loss)
+            loss = step(mx.array(x[idx]))
+            mx.eval(state)
+            steps.set_postfix(val=loss.item())

     # Plot samples from trained flow
@@ -1,4 +1,4 @@
-mlx
+mlx>=0.2
 numpy
 tqdm
 scikit-learn
@@ -1,5 +1,6 @@
 import argparse
 import time
+from functools import partial

 import kwt
 import mlx.core as mx
@@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None):
         .key_transform("audio", normalize)
         .shuffle()
         .batch(batch_size)
+        .to_stream()
+        .prefetch(4, 4)
     )
     return data_iter


-def eval_fn(model, inp, tgt):
-    return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
+def eval_fn(model, x, y):
+    return mx.mean(mx.argmax(model(x), axis=1) == y)


 def train_epoch(model, train_iter, optimizer, epoch):
-    def train_step(model, inp, tgt):
-        output = model(inp)
-        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
-        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
+    def train_step(model, x, y):
+        output = model(x)
+        loss = mx.mean(nn.losses.cross_entropy(output, y))
+        acc = mx.mean(mx.argmax(output, axis=1) == y)
         return loss, acc

-    train_step_fn = nn.value_and_grad(model, train_step)
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(x, y):
+        (loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y)
+        optimizer.update(model, grads)
+        return loss, acc

     losses = []
     accs = []
@@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch):
         x = mx.array(batch["audio"])
         y = mx.array(batch["label"])
         tic = time.perf_counter()
-        (loss, acc), grads = train_step_fn(model, x, y)
-        optimizer.update(model, grads)
-        mx.eval(model.parameters(), optimizer.state)
+        loss, acc = step(x, y)
+        mx.eval(state)
         toc = time.perf_counter()
         loss = loss.item()
         acc = acc.item()
@@ -1,2 +1,2 @@
-mlx
+mlx>=0.2
 mlx-data
@@ -2,6 +2,7 @@

 import math
 import time
+from functools import partial

 import datasets
 import mlx.core as mx
@@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
         x = self.transformer(x, mask)
         return self.out_proj(x)

-    def loss(self, x, y, reduce=True):
-        logits = self(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
-

 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

+    def loss_fn(model, x, y, reduce=True):
+        logits = model(x)
+        losses = nn.losses.cross_entropy(logits, y)
+        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
-    loss_and_grad_fn = nn.value_and_grad(model, model.loss)

-    def eval_fn(model, dataset):
+    def eval_fn(dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = model.loss(bx, by, reduce=False)
+            losses = loss_fn(model, bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)

+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        optimizer.update(model, grads)
+        return loss
+
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
    tic = time.perf_counter()
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
-        loss, grads = loss_and_grad_fn(inputs, targets)
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        optimizer.update(model, grads)
-        del grads
-        mx.eval(loss, model.parameters())
+        loss = step(inputs, targets)
+        mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
             train_loss = np.mean(losses)
@@ -1 +1 @@
-mlx >= 0.12
+mlx>=0.2