mlx-examples/gcn/main.py

import time
from argparse import ArgumentParser, BooleanOptionalAction
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

from datasets import load_data, train_val_test_mask
from gcn import GCN
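

# Cross-entropy loss with optional L2 regularization: the penalty is
# weight_decay times the L2 norm of all model parameters.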
def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    loss = mx.mean(nn.losses.cross_entropy(y_hat, y))
    if weight_decay != 0.0:
        assert parameters is not None, "Model parameters missing for L2 reg."
        l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
        return loss + weight_decay * l2_reg
    return loss
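

# Accuracy: fraction of predicted classes that match the labels.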
def eval_fn(x, y):
    return mx.mean(mx.argmax(x, axis=1) == y)
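

# Forward pass plus masked training loss; the logits are returned as well
# so the caller can reuse them for validation without a second forward pass.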
def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    y_hat = gcn(x, adj)
    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
    return loss, y_hat


def main(args):
    # Data loading
    x, y, adj = load_data(args)
    train_mask, val_mask, test_mask = train_val_test_mask()

    gcn = GCN(
        x_dim=x.shape[-1],
        h_dim=args.hidden_dim,
        out_dim=args.nb_classes,
        nb_layers=args.nb_layers,
        dropout=args.dropout,
        bias=args.bias,
    )
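    # MLX arrays are lazy; force evaluation of the freshly initialized
    # parameters once before training starts.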
    mx.eval(gcn.parameters())

    optimizer = optim.Adam(learning_rate=args.lr)
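
    # Everything the compiled step reads and mutates: model parameters,
    # optimizer state, and the RNG state consumed by dropout.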
    state = [gcn.state, optimizer.state, mx.random.state]
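
    # Compile the training step, declaring the captured state as both input
    # and output so mx.compile tracks its updates across calls.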
    @partial(mx.compile, inputs=state, outputs=state)
    def step():
        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
        (loss, y_hat), grads = loss_and_grad_fn(
            gcn, x, adj, y, train_mask, args.weight_decay
        )
        optimizer.update(gcn, grads)
        return loss, y_hat

    best_val_loss = float("inf")
    cnt = 0

    # Training loop
    for epoch in range(args.epochs):
        tic = time.time()
        loss, y_hat = step()
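        # Evaluating the captured state runs the compiled step and
        # materializes this epoch's parameter, optimizer, and RNG updates.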
        mx.eval(state)

        # Validation
        val_loss = loss_fn(y_hat[val_mask], y[val_mask])
        val_acc = eval_fn(y_hat[val_mask], y[val_mask])
        toc = time.time()

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1
            if cnt == args.patience:
                break

        print(
            " | ".join(
                [
                    f"Epoch: {epoch:3d}",
                    f"Train loss: {loss.item():.3f}",
                    f"Val loss: {val_loss.item():.3f}",
                    f"Val acc: {val_acc.item():.2f}",
                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                ]
            )
        )

    # Test (take the model out of training mode so dropout is disabled)
    gcn.eval()
    test_y_hat = gcn(x, adj)
    test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
    test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
    parser.add_argument("--hidden_dim", type=int, default=20)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--nb_layers", type=int, default=2)
    parser.add_argument("--nb_classes", type=int, default=7)
    # `type=bool` would treat any non-empty string (including "False") as
    # True, so expose an explicit on/off flag instead.
    parser.add_argument("--bias", action=BooleanOptionalAction, default=True)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=100)

    args = parser.parse_args()
    main(args)