mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Update a few examples to use compile (#420)
* update a few examples to use compile * update mnist * add compile to vae and rename some stuff for simplicity * update reqs * use state in eval * GCN example with RNG + dropout * add a bit of prefetching
This commit is contained in:
25
gcn/main.py
25
gcn/main.py
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -47,23 +49,31 @@ def main(args):
|
||||
mx.eval(gcn.parameters())
|
||||
|
||||
optimizer = optim.Adam(learning_rate=args.lr)
|
||||
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
|
||||
|
||||
state = [gcn.state, optimizer.state, mx.random.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step():
|
||||
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
|
||||
(loss, y_hat), grads = loss_and_grad_fn(
|
||||
gcn, x, adj, y, train_mask, args.weight_decay
|
||||
)
|
||||
optimizer.update(gcn, grads)
|
||||
return loss, y_hat
|
||||
|
||||
best_val_loss = float("inf")
|
||||
cnt = 0
|
||||
|
||||
# Training loop
|
||||
for epoch in range(args.epochs):
|
||||
# Loss
|
||||
(loss, y_hat), grads = loss_and_grad_fn(
|
||||
gcn, x, adj, y, train_mask, args.weight_decay
|
||||
)
|
||||
optimizer.update(gcn, grads)
|
||||
mx.eval(gcn.parameters(), optimizer.state)
|
||||
tic = time.time()
|
||||
loss, y_hat = step()
|
||||
mx.eval(state)
|
||||
|
||||
# Validation
|
||||
val_loss = loss_fn(y_hat[val_mask], y[val_mask])
|
||||
val_acc = eval_fn(y_hat[val_mask], y[val_mask])
|
||||
toc = time.time()
|
||||
|
||||
# Early stopping
|
||||
if val_loss < best_val_loss:
|
||||
@@ -81,6 +91,7 @@ def main(args):
|
||||
f"Train loss: {loss.item():.3f}",
|
||||
f"Val loss: {val_loss.item():.3f}",
|
||||
f"Val acc: {val_acc.item():.2f}",
|
||||
f"Time: {1e3*(toc - tic):.3f} (ms)",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
Reference in New Issue
Block a user