Use tree_flatten to compute the L2 regularization term over all parameters

This commit is contained in:
Tristan Bilot 2023-12-11 20:15:11 +01:00
parent ed5a830626
commit b95e48e146

View File

@ -4,6 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.losses import cross_entropy
from mlx.utils import tree_flatten
from datasets import download_cora, load_data, train_val_test_mask
from gcn import GCN
@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
l2_reg = mx.zeros(
1,
)
for k1, v1 in parameters.items():
for k2, v2 in v1.items():
l2_reg += mx.sum(v2["weight"] ** 2)
for leaf in tree_flatten(parameters):
l2_reg += mx.sum(leaf[1] ** 2)
l2_reg = mx.sqrt(l2_reg)
return l + weight_decay * l2_reg.item()
return l