mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
use tree_flatten within L2 regularization
This commit is contained in:
parent
ed5a830626
commit
b95e48e146
@ -4,6 +4,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from mlx.nn.losses import cross_entropy
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from datasets import download_cora, load_data, train_val_test_mask
|
||||
from gcn import GCN
|
||||
@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
|
||||
l2_reg = mx.zeros(
|
||||
1,
|
||||
)
|
||||
for k1, v1 in parameters.items():
|
||||
for k2, v2 in v1.items():
|
||||
l2_reg += mx.sum(v2["weight"] ** 2)
|
||||
for leaf in tree_flatten(parameters):
|
||||
l2_reg += mx.sum(leaf[1] ** 2)
|
||||
l2_reg = mx.sqrt(l2_reg)
|
||||
return l + weight_decay * l2_reg.item()
|
||||
return l
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user