From ed5a83062628d90f5b1a0ab57a5a6a700e740720 Mon Sep 17 00:00:00 2001
From: Tristan Bilot
Date: Mon, 11 Dec 2023 17:48:07 +0100
Subject: [PATCH 1/3] add GCN implementation

---
 gcn/README.md        |  17 ++++++
 gcn/datasets.py      | 114 ++++++++++++++++++++++++++++++++++++++++
 gcn/gcn.py           |  31 +++++++++++
 gcn/main.py          | 120 +++++++++++++++++++++++++++++++++++++++++++
 gcn/requirements.txt |   4 ++
 5 files changed, 286 insertions(+)
 create mode 100644 gcn/README.md
 create mode 100644 gcn/datasets.py
 create mode 100644 gcn/gcn.py
 create mode 100644 gcn/main.py
 create mode 100644 gcn/requirements.txt

diff --git a/gcn/README.md b/gcn/README.md
new file mode 100644
index 00000000..fafcb8e1
--- /dev/null
+++ b/gcn/README.md
@@ -0,0 +1,17 @@
+# Graph Convolutional Network
+
+An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
+
+### Install requirements
+First, install the few dependencies with `pip`.
+
+```
+pip install -r requirements.txt
+```
+
+### Run
+To try the model, just run the `main.py` file. This will download the Cora dataset, run the training and testing.
+
+```
+python main.py
+```
diff --git a/gcn/datasets.py b/gcn/datasets.py
new file mode 100644
index 00000000..56d87e71
--- /dev/null
+++ b/gcn/datasets.py
@@ -0,0 +1,114 @@
+import os
+import requests
+import tarfile
+
+import mlx.core as mx
+import numpy as np
+import scipy.sparse as sparse
+
+"""
+Preprocessing follows the same implementation as in:
+https://github.com/tkipf/gcn
+https://github.com/senadkurtisi/pytorch-GCN/tree/main
+"""
+
+
+def download_cora():
+    """Downloads the cora dataset into a local cora folder."""
+
+    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
+    extract_to = "."
+
+    if os.path.exists(os.path.join(extract_to, "cora")):
+        return
+
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        file_path = os.path.join(extract_to, url.split("/")[-1])
+
+        # Write the file to local disk
+        with open(file_path, "wb") as file:
+            file.write(response.raw.read())
+
+        # Extract the .tgz file
+        with tarfile.open(file_path, "r:gz") as tar:
+            tar.extractall(path=extract_to)
+            print(f"Cora dataset extracted to {extract_to}")
+
+        os.remove(file_path)
+
+
+def train_val_test_mask(labels, num_classes):
+    """Splits the loaded dataset into train/validation/test sets."""
+
+    train_set = mx.array(list(range(140)))
+    validation_set = mx.array(list(range(200, 500)))
+    test_set = mx.array(list(range(500, 1500)))
+
+    return train_set, validation_set, test_set
+
+
+def enumerate_labels(labels):
+    """Converts the labels from the original
+    string form to the integer [0:MaxLabels-1]
+    """
+    unique = list(set(labels))
+    labels = np.array([unique.index(label) for label in labels])
+    return labels
+
+
+def normalize_adjacency(adj):
+    """Normalizes the adjacency matrix according to the
+    paper by Kipf et al.
+    https://arxiv.org/pdf/1609.02907.pdf
+    """
+    adj = adj + sparse.eye(adj.shape[0])
+
+    node_degrees = np.array(adj.sum(1))
+    node_degrees = np.power(node_degrees, -0.5).flatten()
+    node_degrees[np.isinf(node_degrees)] = 0.0
+    node_degrees[np.isnan(node_degrees)] = 0.0
+    degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
+
+    adj = degree_matrix @ adj @ degree_matrix
+    return adj
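+
+
+# Note: the returned matrix is the renormalized adjacency from Kipf & Welling,
+# A_hat = D^{-1/2} (A + I) D^{-1/2}, where D is the degree matrix of A + I.
+# Multiplying by A_hat averages each node's features over its self-inclusive
+# neighborhood, which is the propagation rule each GCN layer applies.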
+
+
+def load_data(config):
+    """Loads the Cora graph data into MLX array format."""
+    print("Loading Cora dataset...")
+
+    # Graph nodes
+    raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
+    raw_node_ids = raw_nodes_data[:, 0].astype(
+        "int32"
+    )  # unique identifier of each node
+    raw_node_labels = raw_nodes_data[:, -1]
+    labels_enumerated = enumerate_labels(raw_node_labels)  # target labels as integers
+    node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
+
+    # Edges
+    ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
+    raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
+    edges_ordered = np.array(
+        list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
+    ).reshape(raw_edges_data.shape)
+
+    # Adjacency matrix
+    adj = sparse.coo_matrix(
+        (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
+        shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
+        dtype=np.float32,
+    )
+
+    # Make the adjacency matrix symmetric
+    adj = adj + adj.T.multiply(adj.T > adj)
+    adj = normalize_adjacency(adj)
+
+    # Convert to mlx array
+    features = mx.array(node_features.toarray(), mx.float32)
+    labels = mx.array(labels_enumerated, mx.int32)
+    adj = mx.array(adj.toarray())
+
+    print("Dataset loaded.")
+    return features, labels, adj
diff --git a/gcn/gcn.py b/gcn/gcn.py
new file mode 100644
index 00000000..91588b35
--- /dev/null
+++ b/gcn/gcn.py
@@ -0,0 +1,31 @@
+import mlx.nn as nn
+
+
+class GCNLayer(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super(GCNLayer, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+    def __call__(self, x, adj):
+        x = self.linear(x)
+        return adj @ x
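+
+# A GCNLayer is one propagation step, H' = A_hat (H W + b): the linear layer
+# transforms node features, then the normalized adjacency aggregates each
+# node's transformed features from those of its neighbors.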
+
+
+class GCN(nn.Module):
+    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
+        super(GCN, self).__init__()
+
+        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
+        self.gcn_layers = [
+            GCNLayer(in_dim, out_dim, bias)
+            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
+        ]
+        self.dropout = nn.Dropout(p=dropout)
+
+    def __call__(self, x, adj):
+        for layer in self.gcn_layers[:-1]:
+            x = nn.relu(layer(x, adj))
+            x = self.dropout(x)
+
+        x = self.gcn_layers[-1](x, adj)
+        return x
diff --git a/gcn/main.py b/gcn/main.py
new file mode 100644
index 00000000..5081d10a
--- /dev/null
+++ b/gcn/main.py
@@ -0,0 +1,120 @@
+from argparse import ArgumentParser
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from mlx.nn.losses import cross_entropy
+
+from datasets import download_cora, load_data, train_val_test_mask
+from gcn import GCN
+
+
+def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
+    l = mx.mean(nn.losses.cross_entropy(y_hat, y))
+
+    if weight_decay != 0.0:
+        assert parameters != None, "Model parameters missing for L2 reg."
+
+        l2_reg = mx.zeros(
+            1,
+        )
+        for k1, v1 in parameters.items():
+            for k2, v2 in v1.items():
+                l2_reg += mx.sum(v2["weight"] ** 2)
+        return l + weight_decay * l2_reg.item()
+    return l
+
+
+def eval_fn(x, y):
+    return mx.mean(mx.argmax(x, axis=1) == y)
+
+
+def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
+    y_hat = gcn(x, adj)
+    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
+    return loss, y_hat
+
+
+def main(args):
+
+    # Data loading
+    download_cora()
+
+    x, y, adj = load_data(args)
+    train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
+
+    gcn = GCN(
+        x_dim=x.shape[-1],
+        h_dim=args.hidden_dim,
+        out_dim=args.nb_classes,
+        nb_layers=args.nb_layers,
+        dropout=args.dropout,
+        bias=args.bias,
+    )
+    mx.eval(gcn.parameters())
+
+    optimizer = optim.Adam(learning_rate=args.lr)
+    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+
+    best_val_loss = float("inf")
+    cnt = 0
+
+    # Training loop
+    for epoch in range(args.epochs):
+
+        # Loss
+        (loss, y_hat), grads = loss_and_grad_fn(
+            gcn, x, adj, y, train_mask, args.weight_decay
+        )
+        optimizer.update(gcn, grads)
+        mx.eval(gcn.parameters(), optimizer.state)
+
+        # Validation
+        val_loss = loss_fn(y_hat[val_mask], y[val_mask])
+        val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+
+        # Early stopping
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            cnt = 0
+        else:
+            cnt += 1
+            if cnt == args.patience:
+                break
+
+        print(
+            " | ".join(
+                [
+                    f"Epoch: {epoch:3d}",
+                    f"Train loss: {loss.item():.3f}",
+                    f"Val loss: {val_loss.item():.3f}",
+                    f"Val acc: {val_acc.item():.2f}",
+                ]
+            )
+        )
+
+    # Test
+    test_y_hat = gcn(x, adj)
+    test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
+    test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
+
+    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
+
+
+if __name__ == "__main__":
+
+    parser = ArgumentParser()
+    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
+    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
+    parser.add_argument("--hidden_dim", type=int, default=20)
+    parser.add_argument("--dropout", type=float, default=0.5)
+    parser.add_argument("--nb_layers", type=int, default=2)
+    parser.add_argument("--nb_classes", type=int, default=7)
+    parser.add_argument("--bias", type=bool, default=True)
+    parser.add_argument("--lr", type=float, default=0.001)
+    parser.add_argument("--weight_decay", type=float, default=0.0)
+    parser.add_argument("--patience", type=int, default=20)
+    parser.add_argument("--epochs", type=int, default=100)
+    args = parser.parse_args()
+
+    main(args)
diff --git a/gcn/requirements.txt b/gcn/requirements.txt
new file mode 100644
index 00000000..d0671eec
--- /dev/null
+++ b/gcn/requirements.txt
@@ -0,0 +1,4 @@
+mlx==0.0.4
+numpy==1.26.2
+scipy==1.11.4
+requests==2.31.0
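The follow-up commit below replaces the hand-rolled two-level loop over `parameters.items()`, which assumes a fixed nesting depth and only touches `weight` leaves, with `mlx.utils.tree_flatten`, which walks the parameter tree at any depth and also includes biases. A minimal sketch of the idea (assuming only `mlx` is installed; `nn.Linear` stands in for the full GCN):

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Linear(4, 2)

# tree_flatten collapses the nested parameter dict into a flat list of
# ("dotted.name", array) pairs, regardless of how deep the module tree is.
leaves = tree_flatten(model.parameters())
print([name for name, _ in leaves])  # e.g. ['weight', 'bias']

# L2 penalty over every parameter leaf, as in the commit below.
l2_reg = sum(mx.sum(p ** 2) for _, p in leaves)
print(l2_reg.item())
```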
From b95e48e1467d20de9491e319bb5f984655db08c4 Mon Sep 17 00:00:00 2001
From: Tristan Bilot
Date: Mon, 11 Dec 2023 20:15:11 +0100
Subject: [PATCH 2/3] use tree_flatten within L2 regularization

---
 gcn/main.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/gcn/main.py b/gcn/main.py
index 5081d10a..24f07f4a 100644
--- a/gcn/main.py
+++ b/gcn/main.py
@@ -4,6 +4,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as optim
 from mlx.nn.losses import cross_entropy
+from mlx.utils import tree_flatten
 
 from datasets import download_cora, load_data, train_val_test_mask
 from gcn import GCN
@@ -18,9 +19,9 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
         l2_reg = mx.zeros(
             1,
         )
-        for k1, v1 in parameters.items():
-            for k2, v2 in v1.items():
-                l2_reg += mx.sum(v2["weight"] ** 2)
+        for leaf in tree_flatten(parameters):
+            l2_reg += mx.sum(leaf[1] ** 2)
+        l2_reg = mx.sqrt(l2_reg)
         return l + weight_decay * l2_reg.item()
     return l
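Among other cleanups, the final commit below rebuilds the label mapping once as a dict instead of calling `list.index` inside the comprehension, turning an O(n·k) scan into O(n + k) for n labels and k classes. A toy illustration of the before/after (NumPy only; the label strings are made up):

```python
import numpy as np

labels = ["nn", "rl", "nn", "theory", "rl"]

# Before: list.index rescans `unique` once per label.
unique = list(set(labels))
slow = np.array([unique.index(label) for label in labels])

# After: build the mapping once, then each lookup is O(1).
label_map = {v: e for e, v in enumerate(set(labels))}
fast = np.array([label_map[label] for label in labels])

# Both produce a permutation-equivalent enumeration of the 3 classes.
assert sorted(set(slow.tolist())) == sorted(set(fast.tolist())) == [0, 1, 2]
```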
From b606bfa6a758d1c3e3d984dc47f566a6d0241273 Mon Sep 17 00:00:00 2001
From: Tristan Bilot
Date: Mon, 11 Dec 2023 23:10:46 +0100
Subject: [PATCH 3/3] fix comments before merge

---
 gcn/.gitignore       |  1 +
 gcn/README.md        |  2 +-
 gcn/datasets.py      | 17 ++++++++++-------
 gcn/main.py          | 13 +++----------
 gcn/requirements.txt |  8 ++++----
 5 files changed, 19 insertions(+), 22 deletions(-)
 create mode 100644 gcn/.gitignore

diff --git a/gcn/.gitignore b/gcn/.gitignore
new file mode 100644
index 00000000..ab0d10a0
--- /dev/null
+++ b/gcn/.gitignore
@@ -0,0 +1 @@
+cora/
diff --git a/gcn/README.md b/gcn/README.md
index fafcb8e1..3da4cebc 100644
--- a/gcn/README.md
+++ b/gcn/README.md
@@ -1,6 +1,6 @@
 # Graph Convolutional Network
 
-An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
+An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX.
 
 ### Install requirements
 First, install the few dependencies with `pip`.
diff --git a/gcn/datasets.py b/gcn/datasets.py
index 56d87e71..d5ab59ad 100644
--- a/gcn/datasets.py
+++ b/gcn/datasets.py
@@ -38,12 +38,12 @@ def download_cora():
     os.remove(file_path)
 
 
-def train_val_test_mask(labels, num_classes):
+def train_val_test_mask():
     """Splits the loaded dataset into train/validation/test sets."""
 
-    train_set = mx.array(list(range(140)))
-    validation_set = mx.array(list(range(200, 500)))
-    test_set = mx.array(list(range(500, 1500)))
+    train_set = mx.arange(140)
+    validation_set = mx.arange(200, 500)
+    test_set = mx.arange(500, 1500)
 
     return train_set, validation_set, test_set
 
@@ -52,15 +52,15 @@ def enumerate_labels(labels):
     """Converts the labels from the original
     string form to the integer [0:MaxLabels-1]
     """
-    unique = list(set(labels))
-    labels = np.array([unique.index(label) for label in labels])
+    label_map = {v: e for e, v in enumerate(set(labels))}
+    labels = np.array([label_map[label] for label in labels])
     return labels
 
 
 def normalize_adjacency(adj):
     """Normalizes the adjacency matrix according to the
     paper by Kipf et al.
-    https://arxiv.org/pdf/1609.02907.pdf
+    https://arxiv.org/abs/1609.02907
     """
     adj = adj + sparse.eye(adj.shape[0])
 
@@ -78,6 +78,9 @@ def load_data(config):
     """Loads the Cora graph data into MLX array format."""
     print("Loading Cora dataset...")
 
+    # Download dataset files
+    download_cora()
+
     # Graph nodes
     raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
     raw_node_ids = raw_nodes_data[:, 0].astype(
diff --git a/gcn/main.py b/gcn/main.py
index 24f07f4a..66c29550 100644
--- a/gcn/main.py
+++ b/gcn/main.py
@@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
     if weight_decay != 0.0:
         assert parameters != None, "Model parameters missing for L2 reg."
 
-        l2_reg = mx.zeros(
-            1,
-        )
-        for leaf in tree_flatten(parameters):
-            l2_reg += mx.sum(leaf[1] ** 2)
-        l2_reg = mx.sqrt(l2_reg)
-        return l + weight_decay * l2_reg.item()
+        l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
+        return l + weight_decay * l2_reg
     return l
 
 
@@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
 def main(args):
 
     # Data loading
-    download_cora()
-
     x, y, adj = load_data(args)
-    train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
+    train_mask, val_mask, test_mask = train_val_test_mask()
 
     gcn = GCN(
         x_dim=x.shape[-1],
diff --git a/gcn/requirements.txt b/gcn/requirements.txt
index d0671eec..3d061551 100644
--- a/gcn/requirements.txt
+++ b/gcn/requirements.txt
@@ -1,4 +1,4 @@
-mlx==0.0.4
-numpy==1.26.2
-scipy==1.11.4
-requests==2.31.0
+mlx>=0.0.4
+numpy>=1.26.2
+scipy>=1.11.4
+requests>=2.31.0
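To smoke-test the merged example without downloading Cora, something like the following should work (a minimal sketch only: the identity matrix is a stand-in for the real normalized adjacency, and the shapes are arbitrary):

```python
import mlx.core as mx
from gcn import GCN  # gcn/gcn.py from this patch series

num_nodes, num_feats, num_classes = 5, 8, 3
x = mx.random.normal((num_nodes, num_feats))
adj = mx.eye(num_nodes)  # stand-in for the normalized Cora adjacency

model = GCN(x_dim=num_feats, h_dim=16, out_dim=num_classes)
logits = model(x, adj)
print(logits.shape)  # (num_nodes, num_classes): one row of class logits per node
```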