mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
use tree_flatten within L2 regularization
This commit is contained in:
parent
ed5a830626
commit
b95e48e146
@ -4,6 +4,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
import mlx.optimizers as optim
|
import mlx.optimizers as optim
|
||||||
from mlx.nn.losses import cross_entropy
|
from mlx.nn.losses import cross_entropy
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
|
||||||
from datasets import download_cora, load_data, train_val_test_mask
|
from datasets import download_cora, load_data, train_val_test_mask
|
||||||
from gcn import GCN
|
from gcn import GCN
|
||||||
@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
|
|||||||
l2_reg = mx.zeros(
|
l2_reg = mx.zeros(
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
for k1, v1 in parameters.items():
|
for leaf in tree_flatten(parameters):
|
||||||
for k2, v2 in v1.items():
|
l2_reg += mx.sum(leaf[1] ** 2)
|
||||||
l2_reg += mx.sum(v2["weight"] ** 2)
|
l2_reg = mx.sqrt(l2_reg)
|
||||||
return l + weight_decay * l2_reg.item()
|
return l + weight_decay * l2_reg.item()
|
||||||
return l
|
return l
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user