Update a few examples to use compile (#420)

* update a few examples to use compile

* update mnist

* add compile to vae and rename some stuff for simplicity

* update reqs

* use state in eval

* GCN example with RNG + dropout

* add a bit of prefetching
This commit is contained in:
Awni Hannun
2024-02-08 13:00:41 -08:00
committed by GitHub
parent da7adae5ec
commit f45a1ab83c
17 changed files with 164 additions and 118 deletions

View File

@@ -1,5 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
@@ -27,18 +29,23 @@ def main(args):
def loss_fn(model, x):
return -mx.mean(model(x))
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.Adam(learning_rate=args.learning_rate)
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x)
optimizer.update(model, grads)
return loss
with trange(args.n_steps) as steps:
for step in steps:
for it in steps:
idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
optimizer.update(model, grads)
mx.eval(model.parameters())
steps.set_postfix(val=loss)
loss = step(mx.array(x[idx]))
mx.eval(state)
steps.set_postfix(val=loss.item())
# Plot samples from trained flow