diff --git a/gcn/main.py b/gcn/main.py index 5081d10a..24f07f4a 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -4,6 +4,7 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from mlx.nn.losses import cross_entropy +from mlx.utils import tree_flatten from datasets import download_cora, load_data, train_val_test_mask from gcn import GCN @@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None): l2_reg = mx.zeros( 1, ) - for k1, v1 in parameters.items(): - for k2, v2 in v1.items(): - l2_reg += mx.sum(v2["weight"] ** 2) + for leaf in tree_flatten(parameters): + l2_reg += mx.sum(leaf[1] ** 2) + l2_reg = mx.sqrt(l2_reg) return l + weight_decay * l2_reg.item() return l