From b95e48e1467d20de9491e319bb5f984655db08c4 Mon Sep 17 00:00:00 2001 From: Tristan Bilot Date: Mon, 11 Dec 2023 20:15:11 +0100 Subject: [PATCH] use tree_flatten within L2 regularization --- gcn/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gcn/main.py b/gcn/main.py index 5081d10a..24f07f4a 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -4,6 +4,7 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from mlx.nn.losses import cross_entropy +from mlx.utils import tree_flatten from datasets import download_cora, load_data, train_val_test_mask from gcn import GCN @@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None): l2_reg = mx.zeros( 1, ) - for k1, v1 in parameters.items(): - for k2, v2 in v1.items(): - l2_reg += mx.sum(v2["weight"] ** 2) + for leaf in tree_flatten(parameters): + l2_reg += mx.sum(leaf[1] ** 2) + l2_reg = mx.sqrt(l2_reg) return l + weight_decay * l2_reg.item() return l