Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
This commit is contained in:
Awni Hannun
2024-02-08 13:00:41 -08:00
committed by GitHub
parent da7adae5ec
commit f45a1ab83c
17 changed files with 164 additions and 118 deletions

View File

@@ -2,6 +2,7 @@
import math
import time
from functools import partial
import datasets
import mlx.core as mx
@@ -37,11 +38,6 @@ class TransformerLM(nn.Module):
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
def to_samples(context_size, dataset):
tokens = dataset.size
@@ -88,31 +84,42 @@ def main(args):
)
print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")
def loss_fn(model, x, y, reduce=True):
logits = model(x)
losses = nn.losses.cross_entropy(logits, y)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
optimizer = optim.AdamW(
learning_rate=args.learning_rate, weight_decay=args.weight_decay
)
loss_and_grad_fn = nn.value_and_grad(model, model.loss)
def eval_fn(model, dataset):
def eval_fn(dataset):
inputs, targets = map(mx.array, to_samples(context_size, dataset))
loss = 0
for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(mx.array, (bx, by))
losses = model.loss(bx, by, reduce=False)
losses = loss(bx, by, reduce=False)
loss += mx.sum(losses).item()
return loss / len(targets)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(inputs, targets):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, inputs, targets)
optimizer.update(model, grads)
return loss
train_iterator = iterate_batches(batch_size, context_size, train)
losses = []
tic = time.perf_counter()
for it, (inputs, targets) in zip(range(args.num_iters), train_iterator):
inputs, targets = map(mx.array, (inputs, targets))
loss, grads = loss_and_grad_fn(inputs, targets)
optimizer.learning_rate = min(1, it / args.lr_warmup) * args.learning_rate
optimizer.update(model, grads)
del grads
mx.eval(loss, model.parameters())
loss = step(inputs, targets)
mx.eval(state)
losses.append(loss.item())
if (it + 1) % steps_per_report == 0:
train_loss = np.mean(losses)