From b606bfa6a758d1c3e3d984dc47f566a6d0241273 Mon Sep 17 00:00:00 2001 From: Tristan Bilot Date: Mon, 11 Dec 2023 23:10:46 +0100 Subject: [PATCH] fix comments before merge --- gcn/.gitignore | 1 + gcn/README.md | 2 +- gcn/datasets.py | 17 ++++++++++------- gcn/main.py | 13 +++---------- gcn/requirements.txt | 8 ++++---- 5 files changed, 19 insertions(+), 22 deletions(-) create mode 100644 gcn/.gitignore diff --git a/gcn/.gitignore b/gcn/.gitignore new file mode 100644 index 00000000..ab0d10a0 --- /dev/null +++ b/gcn/.gitignore @@ -0,0 +1 @@ +cora/ diff --git a/gcn/README.md b/gcn/README.md index fafcb8e1..3da4cebc 100644 --- a/gcn/README.md +++ b/gcn/README.md @@ -1,6 +1,6 @@ # Graph Convolutional Network -An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX. +An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX. ### Install requirements First, install the few dependencies with `pip`. diff --git a/gcn/datasets.py b/gcn/datasets.py index 56d87e71..d5ab59ad 100644 --- a/gcn/datasets.py +++ b/gcn/datasets.py @@ -38,12 +38,12 @@ def download_cora(): os.remove(file_path) -def train_val_test_mask(labels, num_classes): +def train_val_test_mask(): """Splits the loaded dataset into train/validation/test sets.""" - train_set = mx.array(list(range(140))) - validation_set = mx.array(list(range(200, 500))) - test_set = mx.array(list(range(500, 1500))) + train_set = mx.arange(140) + validation_set = mx.arange(200, 500) + test_set = mx.arange(500, 1500) return train_set, validation_set, test_set @@ -52,15 +52,15 @@ def enumerate_labels(labels): """Converts the labels from the original string form to the integer [0:MaxLabels-1] """ - unique = list(set(labels)) - labels = np.array([unique.index(label) for label in labels]) + label_map = {v: e for e, v in enumerate(set(labels))} + labels = np.array([label_map[label] for label in labels]) return labels def normalize_adjacency(adj): """Normalizes the adjacency matrix according to the paper by Kipf et al. - https://arxiv.org/pdf/1609.02907.pdf + https://arxiv.org/abs/1609.02907 """ adj = adj + sparse.eye(adj.shape[0]) @@ -78,6 +78,9 @@ def load_data(config): """Loads the Cora graph data into MLX array format.""" print("Loading Cora dataset...") + # Download dataset files + download_cora() + # Graph nodes raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str") raw_node_ids = raw_nodes_data[:, 0].astype( diff --git a/gcn/main.py b/gcn/main.py index 24f07f4a..66c29550 100644 --- a/gcn/main.py +++ b/gcn/main.py @@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None): if weight_decay != 0.0: assert parameters != None, "Model parameters missing for L2 reg." - l2_reg = mx.zeros( - 1, - ) - for leaf in tree_flatten(parameters): - l2_reg += mx.sum(leaf[1] ** 2) - l2_reg = mx.sqrt(l2_reg) - return l + weight_decay * l2_reg.item() + l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt() + return l + weight_decay * l2_reg return l @@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay): def main(args): # Data loading - download_cora() - x, y, adj = load_data(args) - train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes) + train_mask, val_mask, test_mask = train_val_test_mask() gcn = GCN( x_dim=x.shape[-1], diff --git a/gcn/requirements.txt b/gcn/requirements.txt index d0671eec..3d061551 100644 --- a/gcn/requirements.txt +++ b/gcn/requirements.txt @@ -1,4 +1,4 @@ -mlx==0.0.4 -numpy==1.26.2 -scipy==1.11.4 -requests==2.31.0 +mlx>=0.0.4 +numpy>=1.26.2 +scipy>=1.11.4 +requests>=2.31.0