mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
fix comments before merge
This commit is contained in:
13
gcn/main.py
13
gcn/main.py
@@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
|
||||
if weight_decay != 0.0:
|
||||
assert parameters != None, "Model parameters missing for L2 reg."
|
||||
|
||||
l2_reg = mx.zeros(
|
||||
1,
|
||||
)
|
||||
for leaf in tree_flatten(parameters):
|
||||
l2_reg += mx.sum(leaf[1] ** 2)
|
||||
l2_reg = mx.sqrt(l2_reg)
|
||||
return l + weight_decay * l2_reg.item()
|
||||
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
|
||||
return l + weight_decay * l2_reg
|
||||
return l
|
||||
|
||||
|
||||
@@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
|
||||
def main(args):
|
||||
|
||||
# Data loading
|
||||
download_cora()
|
||||
|
||||
x, y, adj = load_data(args)
|
||||
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
|
||||
train_mask, val_mask, test_mask = train_val_test_mask()
|
||||
|
||||
gcn = GCN(
|
||||
x_dim=x.shape[-1],
|
||||
|
Reference in New Issue
Block a user