Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
This commit is contained in:
Awni Hannun
2024-02-08 13:00:41 -08:00
committed by GitHub
parent da7adae5ec
commit f45a1ab83c
17 changed files with 164 additions and 118 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import time
from functools import partial
import kwt
import mlx.core as mx
@@ -46,22 +47,30 @@ def prepare_dataset(batch_size, split, root=None):
.key_transform("audio", normalize)
.shuffle()
.batch(batch_size)
.to_stream()
.prefetch(4, 4)
)
return data_iter
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
def eval_fn(model, x, y):
return mx.mean(mx.argmax(model(x), axis=1) == y)
def train_epoch(model, train_iter, optimizer, epoch):
def train_step(model, inp, tgt):
output = model(inp)
loss = mx.mean(nn.losses.cross_entropy(output, tgt))
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
def train_step(model, x, y):
output = model(x)
loss = mx.mean(nn.losses.cross_entropy(output, y))
acc = mx.mean(mx.argmax(output, axis=1) == y)
return loss, acc
train_step_fn = nn.value_and_grad(model, train_step)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
(loss, acc), grads = nn.value_and_grad(model, train_step)(model, x, y)
optimizer.update(model, grads)
return loss, acc
losses = []
accs = []
@@ -72,9 +81,8 @@ def train_epoch(model, train_iter, optimizer, epoch):
x = mx.array(batch["audio"])
y = mx.array(batch["label"])
tic = time.perf_counter()
(loss, acc), grads = train_step_fn(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
loss, acc = step(x, y)
mx.eval(state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()

View File

@@ -1,2 +1,2 @@
mlx
mlx>=0.2
mlx-data