mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
121 lines
3.4 KiB
Python
121 lines
3.4 KiB
Python
![]() |
from argparse import ArgumentParser
|
||
|
|
||
|
import mlx.core as mx
|
||
|
import mlx.nn as nn
|
||
|
import mlx.optimizers as optim
|
||
|
from mlx.nn.losses import cross_entropy
|
||
|
|
||
|
from datasets import download_cora, load_data, train_val_test_mask
|
||
|
from gcn import GCN
|
||
|
|
||
|
|
||
|
def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    """Mean cross-entropy loss with an optional L2 penalty on weight parameters.

    Args:
        y_hat: Predicted logits, one row per node (presumably
            (num_nodes, num_classes) -- TODO confirm against GCN output).
        y: Class targets aligned with the rows of ``y_hat``.
        weight_decay: Coefficient of the L2 penalty; ``0.0`` disables it.
        parameters: Parameter tree (e.g. ``model.parameters()``). Required
            when ``weight_decay`` is non-zero. Every leaf whose key is
            ``"weight"`` is penalised; other leaves (e.g. biases) are not.

    Returns:
        A scalar ``mx.array``: the mean cross-entropy, plus
        ``weight_decay * sum(w ** 2)`` over all weights when enabled.

    Raises:
        ValueError: if ``weight_decay`` is non-zero but ``parameters`` is None.
    """
    loss = mx.mean(cross_entropy(y_hat, y))

    if weight_decay == 0.0:
        return loss

    if parameters is None:
        raise ValueError("Model parameters missing for L2 reg.")

    from mlx.utils import tree_flatten

    # Walk the whole parameter tree instead of assuming a fixed nesting depth:
    # sub-modules stored in lists show up as nested lists, not dicts.
    weights = [
        value
        for key, value in tree_flatten(parameters)
        if key.rsplit(".", 1)[-1] == "weight"
    ]
    # Keep the penalty as an mx.array. Converting it to a Python float (as
    # ``.item()`` does) detaches it from the graph, so it would contribute
    # nothing to the gradients computed by value_and_grad.
    l2_reg = sum((mx.sum(w**2) for w in weights), mx.array(0.0))
    return loss + weight_decay * l2_reg
|
||
|
|
||
|
|
||
|
def eval_fn(x, y):
    """Return the accuracy of logits ``x`` against labels ``y``.

    The predicted class of each row is its argmax over axis 1. The result is
    the mean of the per-row matches.
    """
    predicted = mx.argmax(x, axis=1)
    hits = predicted == y
    return mx.mean(hits)
|
||
|
|
||
|
|
||
|
def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    """Run ``gcn`` on the full graph and score it on the training nodes.

    Returns a ``(loss, y_hat)`` pair.
    - ``loss`` is ``loss_fn`` over the nodes selected by ``train_mask``. It
      includes the L2 term when ``weight_decay`` is non-zero.
    - ``y_hat`` is the model output for every node.

    The loss comes first because it is the value differentiated when this
    function is wrapped by ``nn.value_and_grad``.
    """
    logits = gcn(x, adj)
    train_logits = logits[train_mask]
    train_targets = y[train_mask]
    train_loss = loss_fn(train_logits, train_targets, weight_decay, gcn.parameters())
    return train_loss, logits
|
||
|
|
||
|
|
||
|
def _evaluate(gcn, x, adj, y, mask):
    """Return ``(loss, accuracy)`` of ``gcn`` on the nodes selected by ``mask``.

    The model is put into eval mode first, so that training-only behaviour
    such as dropout (if the model uses it) is disabled. It is left in eval
    mode; callers must call ``gcn.train()`` before the next training step.
    """
    gcn.eval()
    y_hat = gcn(x, adj)
    return loss_fn(y_hat[mask], y[mask]), eval_fn(y_hat[mask], y[mask])


def main(args):
    """Train a GCN on the Cora dataset with early stopping, then report test metrics.

    ``args`` is the parsed command-line namespace built in the ``__main__``
    block.
    """
    # Data loading
    download_cora()

    x, y, adj = load_data(args)
    train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)

    gcn = GCN(
        x_dim=x.shape[-1],
        h_dim=args.hidden_dim,
        out_dim=args.nb_classes,
        nb_layers=args.nb_layers,
        dropout=args.dropout,
        bias=args.bias,
    )
    mx.eval(gcn.parameters())

    optimizer = optim.Adam(learning_rate=args.lr)
    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

    best_val_loss = float("inf")
    cnt = 0

    # Training loop
    for epoch in range(args.epochs):
        # Training step; _evaluate leaves the model in eval mode, so switch back.
        gcn.train()
        (loss, _), grads = loss_and_grad_fn(
            gcn, x, adj, y, train_mask, args.weight_decay
        )
        optimizer.update(gcn, grads)
        mx.eval(gcn.parameters(), optimizer.state)

        # Validation on a fresh inference-mode forward pass with the updated
        # weights, rather than on the (dropout-perturbed) training output.
        val_loss, val_acc = _evaluate(gcn, x, adj, y, val_mask)

        # Early stopping: stop after `patience` epochs without improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1
            if cnt >= args.patience:
                break

        print(
            " | ".join(
                [
                    f"Epoch: {epoch:3d}",
                    f"Train loss: {loss.item():.3f}",
                    f"Val loss: {val_loss.item():.3f}",
                    f"Val acc: {val_acc.item():.2f}",
                ]
            )
        )

    # Test: evaluate the final model in inference mode. Previously the stale
    # training-mode output of the last epoch was used here by mistake.
    test_loss, test_acc = _evaluate(gcn, x, adj, y, test_mask)

    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
|
||
|
|
||
|
|
||
|
def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used with argparse for this. It would call
    ``bool("False")``, which is ``True`` for any non-empty string, so the
    option could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
    parser.add_argument("--hidden_dim", type=int, default=20)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--nb_layers", type=int, default=2)
    parser.add_argument("--nb_classes", type=int, default=7)
    parser.add_argument(
        "--bias",
        type=str2bool,
        default=True,
        help="Use bias in GCN layers (true/false).",
    )
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args()

    main(args)
|