mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-11 14:24:35 +08:00
Update a few examples to use compile (#420)
* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -27,18 +29,23 @@ def main(args):
|
||||
def loss_fn(model, x):
|
||||
return -mx.mean(model(x))
|
||||
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
optimizer = optim.Adam(learning_rate=args.learning_rate)
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x):
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
loss, grads = loss_and_grad_fn(model, x)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
with trange(args.n_steps) as steps:
|
||||
for step in steps:
|
||||
for it in steps:
|
||||
idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
|
||||
loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
|
||||
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
steps.set_postfix(val=loss)
|
||||
loss = step(mx.array(x[idx]))
|
||||
mx.eval(state)
|
||||
steps.set_postfix(val=loss.item())
|
||||
|
||||
# Plot samples from trained flow
|
||||
|
||||
|
Reference in New Issue
Block a user