Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-10 13:07:28 +08:00)
Update a few examples to use compile (#420)
* update a few examples to use compile
* update mnist
* add compile to vae and rename some stuff for simplicity
* update reqs
* use state in eval
* GCN example with RNG + dropout
* add a bit of prefetching
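The change these examples share, shown concretely for the transformer LM in the diff below, is to wrap the training step in mx.compile and pass the model and optimizer state as the compiled function's implicit inputs and outputs. What follows is a minimal sketch of that pattern, assuming a toy model, SGD optimizer, and loss purely for illustration (none of them are from this commit); the mx.random.state entry is the usual way to handle dropout RNG under compile, relevant to the GCN item above but not part of the diff shown here.

from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Illustrative model, optimizer, and loss (not the ones in this commit).
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 4))
optimizer = optim.SGD(learning_rate=1e-2)


def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y).mean()


# State the compiled step reads and writes: parameters, optimizer buffers,
# and (because of Dropout) the global RNG state.
state = [model.state, optimizer.state, mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss


x = mx.random.normal((8, 16))
y = mx.random.randint(0, 4, (8,))
loss = step(x, y)
mx.eval(state)  # materialize the updated parameters and optimizer state
print(loss.item())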
@@ -2,6 +2,7 @@
 
 import math
 import time
+from functools import partial
 
 import datasets
 import mlx.core as mx
@@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
         x = self.transformer(x, mask)
         return self.out_proj(x)
 
-    def loss(self, x, y, reduce=True):
-        logits = self(x)
-        losses = nn.losses.cross_entropy(logits, y)
-        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
-
 
 def to_samples(context_size, dataset):
     tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
     )
     print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
 
+    def loss_fn(model, x, y, reduce=True):
+        logits = model(x)
+        losses = nn.losses.cross_entropy(logits, y)
+        return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
+
     optimizer = optim.AdamW(
         learning_rate=args.learning_rate, weight_decay=args.weight_decay
     )
-    loss_and_grad_fn = nn.value_and_grad(model, model.loss)
 
-    def eval_fn(model, dataset):
+    def eval_fn(dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = model.loss(bx, by, reduce=False)
+            losses = loss_fn(model, bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
 
+    state = [model.state, optimizer.state]
+
+    @partial(mx.compile, inputs=state, outputs=state)
+    def step(inputs, targets):
+        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+        loss, grads = loss_and_grad_fn(model, inputs, targets)
+        optimizer.update(model, grads)
+        return loss
+
     train_iterator = iterate_batches(batch_size, context_size, train)
     losses = []
     tic = time.perf_counter()
     for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
         inputs, targets = map(mx.array, (inputs, targets))
-        loss, grads = loss_and_grad_fn(inputs, targets)
         optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
-        optimizer.update(model, grads)
-        del grads
-        mx.eval(loss, model.parameters())
+        loss = step(inputs, targets)
+        mx.eval(state)
         losses.append(loss.item())
         if (it + 1) % steps_per_report == 0:
             train_loss = np.mean(losses)
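A note on the training-loop change above: because step is compiled with inputs=state and outputs=state, the parameter and optimizer updates it performs flow through the captured state, so a single mx.eval(state) per iteration is enough to materialize them. It replaces the previous mx.eval(loss, model.parameters()), and del grads goes away because the gradients never leave the compiled function; the loss itself is still forced by loss.item() when it is appended. This appears to be what "use state in eval" in the commit message refers to:

    # Per iteration (as in the diff): one compiled step, one eval of the state.
    loss = step(inputs, targets)
    mx.eval(state)              # updated parameters and optimizer buffers
    losses.append(loss.item())  # .item() also evaluates the loss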