mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 05:58:07 +08:00
Update a few examples to use compile (#420)
* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -34,10 +35,6 @@ def loss_fn(model, X, y):
|
||||
return nn.losses.cross_entropy(model(X), y, reduction="mean")
|
||||
|
||||
|
||||
def eval_fn(model, X, y):
|
||||
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||
|
||||
|
||||
def batch_iterate(batch_size, X, y):
|
||||
perm = mx.array(np.random.permutation(y.size))
|
||||
for s in range(0, y.size, batch_size):
|
||||
@@ -65,16 +62,25 @@ def main(args):
|
||||
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
optimizer = optim.SGD(learning_rate=learning_rate)
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
|
||||
@partial(mx.compile, inputs=model.state, outputs=model.state)
|
||||
def step(X, y):
|
||||
loss, grads = loss_and_grad_fn(model, X, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
@partial(mx.compile, inputs=model.state)
|
||||
def eval_fn(X, y):
|
||||
return mx.mean(mx.argmax(model(X), axis=1) == y)
|
||||
|
||||
for e in range(num_epochs):
|
||||
tic = time.perf_counter()
|
||||
for X, y in batch_iterate(batch_size, train_images, train_labels):
|
||||
loss, grads = loss_and_grad_fn(model, X, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
accuracy = eval_fn(model, test_images, test_labels)
|
||||
step(X, y)
|
||||
mx.eval(model.state)
|
||||
accuracy = eval_fn(test_images, test_labels)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Epoch {e}: Test accuracy {accuracy.item():.3f},"
|
||||
|
@@ -1,2 +1,2 @@
|
||||
mlx
|
||||
numpy
|
||||
mlx>=0.2
|
||||
numpy
|
||||
|
Reference in New Issue
Block a user