mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
@@ -32,7 +32,6 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
# Data loading
|
||||
x, y, adj = load_data(args)
|
||||
train_mask, val_mask, test_mask = train_val_test_mask()
|
||||
@@ -55,7 +54,6 @@ def main(args):
|
||||
|
||||
# Training loop
|
||||
for epoch in range(args.epochs):
|
||||
|
||||
# Loss
|
||||
(loss, y_hat), grads = loss_and_grad_fn(
|
||||
gcn, x, adj, y, train_mask, args.weight_decay
|
||||
@@ -96,7 +94,6 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
|
||||
parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
|
||||
|
Reference in New Issue
Block a user