fix comments before merge

This commit is contained in:
Tristan Bilot
2023-12-11 23:10:46 +01:00
parent b95e48e146
commit b606bfa6a7
5 changed files with 19 additions and 22 deletions

View File

@@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
if weight_decay != 0.0:
assert parameters != None, "Model parameters missing for L2 reg."
l2_reg = mx.zeros(
1,
)
for leaf in tree_flatten(parameters):
l2_reg += mx.sum(leaf[1] ** 2)
l2_reg = mx.sqrt(l2_reg)
return l + weight_decay * l2_reg.item()
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
return l + weight_decay * l2_reg
return l
@@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
def main(args):
# Data loading
download_cora()
x, y, adj = load_data(args)
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
train_mask, val_mask, test_mask = train_val_test_mask()
gcn = GCN(
x_dim=x.shape[-1],