mlx-examples/gcn/main.py

122 lines
3.5 KiB
Python
Raw Normal View History

import time
2023-12-12 00:48:07 +08:00
from argparse import ArgumentParser
from functools import partial
2023-12-12 00:48:07 +08:00
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from datasets import load_data, train_val_test_mask
from mlx.utils import tree_flatten
2023-12-12 00:48:07 +08:00
from gcn import GCN
def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    """Return the mean cross-entropy loss, optionally with an L2 penalty.

    Args:
        y_hat: Model outputs passed to ``nn.losses.cross_entropy``.
        y: Targets passed to ``nn.losses.cross_entropy``.
        weight_decay: Coefficient of the regularization term. When it is
            ``0.0`` no regularization is added and ``parameters`` is ignored.
        parameters: Model parameter tree; required when ``weight_decay`` is
            non-zero.

    Returns:
        The mean cross-entropy, plus ``weight_decay`` times the L2 norm
        (square root of the summed squares) of every flattened parameter
        when ``weight_decay`` is non-zero.

    Raises:
        ValueError: If ``weight_decay`` is non-zero and ``parameters`` is
            ``None``.
    """
    # Validate up front with a real exception: an ``assert`` is stripped
    # under ``python -O`` and would let ``None`` reach ``tree_flatten``.
    if weight_decay != 0.0 and parameters is None:
        raise ValueError("Model parameters missing for L2 reg.")

    loss = mx.mean(nn.losses.cross_entropy(y_hat, y))
    if weight_decay == 0.0:
        return loss

    # ``tree_flatten`` yields (name, array) pairs; only the arrays matter.
    squared_sum = sum(mx.sum(p**2) for _, p in tree_flatten(parameters))
    l2_reg = squared_sum.sqrt()
    return loss + weight_decay * l2_reg
def eval_fn(x, y):
    """Return the mean of the element-wise match between argmax(x) and y.

    The argmax is taken along axis 1 of ``x`` and compared for equality
    with ``y``; the mean of that comparison is returned.
    """
    predicted = mx.argmax(x, axis=1)
    matches = predicted == y
    return mx.mean(matches)
def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    """Run the model on the full graph and compute the training loss.

    Args:
        gcn: Callable model invoked as ``gcn(x, adj)``; must also expose
            ``parameters()``.
        x: First model input.
        adj: Second model input.
        y: Targets, indexed with ``train_mask``.
        train_mask: Index/mask selecting the training entries.
        weight_decay: Regularization coefficient forwarded to ``loss_fn``.

    Returns:
        A ``(loss, outputs)`` tuple, where ``outputs`` is the model output
        for every entry (not only the masked ones).
    """
    outputs = gcn(x, adj)
    train_loss = loss_fn(
        outputs[train_mask],
        y[train_mask],
        weight_decay=weight_decay,
        parameters=gcn.parameters(),
    )
    return train_loss, outputs
def main(args):
    """Train a GCN with early stopping, then print test loss and accuracy.

    Args:
        args: Parsed command-line namespace. This function reads
            ``hidden_dim``, ``nb_classes``, ``nb_layers``, ``dropout``,
            ``bias``, ``lr``, ``weight_decay``, ``patience`` and ``epochs``
            from it, and forwards the whole namespace to ``load_data``.
    """
    # Data loading
    x, y, adj = load_data(args)
    train_mask, val_mask, test_mask = train_val_test_mask()

    gcn = GCN(
        x_dim=x.shape[-1],
        h_dim=args.hidden_dim,
        out_dim=args.nb_classes,
        nb_layers=args.nb_layers,
        dropout=args.dropout,
        bias=args.bias,
    )
    mx.eval(gcn.parameters())

    optimizer = optim.Adam(learning_rate=args.lr)

    # Implicit state that the compiled step reads and updates; it is
    # declared to ``mx.compile`` as captured inputs and outputs.
    state = [gcn.state, optimizer.state, mx.random.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step():
        """Run one full-graph training update; return (loss, outputs)."""
        loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
        (loss, y_hat), grads = loss_and_grad_fn(
            gcn, x, adj, y, train_mask, args.weight_decay
        )
        optimizer.update(gcn, grads)
        return loss, y_hat

    best_val_loss = float("inf")
    cnt = 0  # consecutive epochs without a validation-loss improvement

    # Training loop
    for epoch in range(args.epochs):
        tic = time.time()
        loss, y_hat = step()
        mx.eval(state)

        # Validation reuses the outputs of the training forward pass.
        val_loss = loss_fn(y_hat[val_mask], y[val_mask])
        val_acc = eval_fn(y_hat[val_mask], y[val_mask])
        toc = time.time()

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1
            if cnt == args.patience:
                break

        print(
            " | ".join(
                [
                    f"Epoch: {epoch:3d}",
                    f"Train loss: {loss.item():.3f}",
                    f"Val loss: {val_loss.item():.3f}",
                    f"Val acc: {val_acc.item():.2f}",
                    f"Time: {1e3*(toc - tic):.3f} (ms)",
                ]
            )
        )

    # Test
    # Switch to evaluation mode so the final forward pass is not affected by
    # training-only behaviour (presumably the dropout configured above).
    # NOTE(review): assumes GCN is an ``nn.Module``, as ``nn.value_and_grad``
    # and ``gcn.state`` above suggest - confirm against gcn.py.
    gcn.eval()
    test_y_hat = gcn(x, adj)
    # Score the fresh predictions. The stale ``y_hat`` from the last training
    # step predates the final parameter update and is undefined when
    # ``epochs`` is 0.
    test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
    test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
if __name__ == "__main__":

    def str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` would call ``bool()`` on the raw string, which is True
        for every non-empty value, so ``--bias False`` would enable the bias.
        Raising ``ValueError`` makes argparse report an invalid-value error.
        """
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept as False: it was the only falsy input
        # under the previous ``type=bool`` behaviour.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise ValueError(f"expected a boolean value, got {value!r}")

    parser = ArgumentParser()
    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
    parser.add_argument("--hidden_dim", type=int, default=20)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--nb_layers", type=int, default=2)
    parser.add_argument("--nb_classes", type=int, default=7)
    parser.add_argument("--bias", type=str2bool, default=True)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args()

    main(args)