From ed5a83062628d90f5b1a0ab57a5a6a700e740720 Mon Sep 17 00:00:00 2001
From: Tristan Bilot
Date: Mon, 11 Dec 2023 17:48:07 +0100
Subject: [PATCH] add GCN implementation

---
 gcn/README.md        |  17 ++++
 gcn/datasets.py      | 114 ++++++++++++++++++++++++++++++++++++++++
 gcn/gcn.py           |  31 +++++++++++
 gcn/main.py          | 120 +++++++++++++++++++++++++++++++++++++++++++
 gcn/requirements.txt |   4 ++
 5 files changed, 286 insertions(+)
 create mode 100644 gcn/README.md
 create mode 100644 gcn/datasets.py
 create mode 100644 gcn/gcn.py
 create mode 100644 gcn/main.py
 create mode 100644 gcn/requirements.txt

diff --git a/gcn/README.md b/gcn/README.md
new file mode 100644
index 00000000..fafcb8e1
--- /dev/null
+++ b/gcn/README.md
@@ -0,0 +1,17 @@
+# Graph Convolutional Network
+
+An example implementation of a [GCN](https://arxiv.org/pdf/1609.02907.pdf) with MLX.
+
+### Install requirements
+First, install the dependencies with `pip`.
+
+```
+pip install -r requirements.txt
+```
+
+### Run
+To try the model, run `main.py`. This will download the Cora dataset, train the model, and evaluate it on the test set.
+
+```
+python main.py
+```
diff --git a/gcn/datasets.py b/gcn/datasets.py
new file mode 100644
index 00000000..56d87e71
--- /dev/null
+++ b/gcn/datasets.py
@@ -0,0 +1,114 @@
+import os
+import requests
+import tarfile
+
+import mlx.core as mx
+import numpy as np
+import scipy.sparse as sparse
+
+"""
+Preprocessing follows the same implementation as in:
+https://github.com/tkipf/gcn
+https://github.com/senadkurtisi/pytorch-GCN/tree/main
+"""
+
+
+def download_cora():
+    """Downloads the Cora dataset into a local cora folder."""
+
+    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
+    extract_to = "."
+
+    if os.path.exists(os.path.join(extract_to, "cora")):
+        return
+
+    response = requests.get(url, stream=True)
+    if response.status_code == 200:
+        file_path = os.path.join(extract_to, url.split("/")[-1])
+
+        # Write the archive to local disk
+        with open(file_path, "wb") as file:
+            file.write(response.raw.read())
+
+        # Extract the .tgz file
+        with tarfile.open(file_path, "r:gz") as tar:
+            tar.extractall(path=extract_to)
+            print(f"Cora dataset extracted to {extract_to}")
+
+        os.remove(file_path)
+    else:
+        raise RuntimeError(f"Failed to download Cora: HTTP {response.status_code}")
+
+
+def train_val_test_mask(labels, num_classes):
+    """Splits the dataset into the standard Cora
+    train/validation/test index sets.
+    """
+    train_set = mx.array(list(range(140)))
+    validation_set = mx.array(list(range(200, 500)))
+    test_set = mx.array(list(range(500, 1500)))
+
+    return train_set, validation_set, test_set
+
+
+def enumerate_labels(labels):
+    """Converts labels from their original string form
+    to integers in [0, num_labels - 1].
+    """
+    unique = list(set(labels))
+    labels = np.array([unique.index(label) for label in labels])
+    return labels
+
+
+def normalize_adjacency(adj):
+    """Normalizes the adjacency matrix according to the
+    paper by Kipf et al.
+    https://arxiv.org/pdf/1609.02907.pdf
+    """
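+    # Symmetric normalization from the paper:
+    #   A_hat = D^{-1/2} (A + I) D^{-1/2}
+    # where D is the degree matrix of the self-loop-augmented adjacency A + I.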
+    adj = adj + sparse.eye(adj.shape[0])
+
+    node_degrees = np.array(adj.sum(1))
+    node_degrees = np.power(node_degrees, -0.5).flatten()
+    node_degrees[np.isinf(node_degrees)] = 0.0
+    node_degrees[np.isnan(node_degrees)] = 0.0
+    degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
+
+    adj = degree_matrix @ adj @ degree_matrix
+    return adj
+
+
+def load_data(config):
+    """Loads the Cora graph data into MLX array format."""
+    print("Loading Cora dataset...")
+
+    # Graph nodes
+    raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
+    raw_node_ids = raw_nodes_data[:, 0].astype(
+        "int32"
+    )  # unique identifier of each node
+    raw_node_labels = raw_nodes_data[:, -1]
+    labels_enumerated = enumerate_labels(raw_node_labels)  # target labels as integers
+    node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
+
+    # Edges
+    ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
+    raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
+    edges_ordered = np.array(
+        list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
+    ).reshape(raw_edges_data.shape)
+
+    # Adjacency matrix
+    adj = sparse.coo_matrix(
+        (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
+        shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
+        dtype=np.float32,
+    )
+
+    # Make the adjacency matrix symmetric
+    adj = adj + adj.T.multiply(adj.T > adj)
+    adj = normalize_adjacency(adj)
+
+    # Convert to MLX arrays
+    features = mx.array(node_features.toarray(), mx.float32)
+    labels = mx.array(labels_enumerated, mx.int32)
+    adj = mx.array(adj.toarray())
+
+    print("Dataset loaded.")
+    return features, labels, adj
diff --git a/gcn/gcn.py b/gcn/gcn.py
new file mode 100644
index 00000000..91588b35
--- /dev/null
+++ b/gcn/gcn.py
@@ -0,0 +1,31 @@
+import mlx.nn as nn
+
+
+class GCNLayer(nn.Module):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+    def __call__(self, x, adj):
+        # Project node features, then aggregate over neighbors
+        # via the normalized adjacency matrix.
+        x = self.linear(x)
+        return adj @ x
+
+
+class GCN(nn.Module):
+    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
+        super().__init__()
+
+        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
+        self.gcn_layers = [
+            GCNLayer(in_dim, out_dim, bias)
+            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
+        ]
+        self.dropout = nn.Dropout(p=dropout)
+
+    def __call__(self, x, adj):
+        for layer in self.gcn_layers[:-1]:
+            x = nn.relu(layer(x, adj))
+            x = self.dropout(x)
+
+        x = self.gcn_layers[-1](x, adj)
+        return x
diff --git a/gcn/main.py b/gcn/main.py
new file mode 100644
index 00000000..5081d10a
--- /dev/null
+++ b/gcn/main.py
@@ -0,0 +1,120 @@
+from argparse import ArgumentParser
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from mlx.utils import tree_flatten
+
+from datasets import download_cora, load_data, train_val_test_mask
+from gcn import GCN
+
+
+def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
+    loss = mx.mean(nn.losses.cross_entropy(y_hat, y))
+
+    if weight_decay != 0.0:
+        assert parameters is not None, "Model parameters missing for L2 reg."
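+
+        # L2 regularization: sum of squared entries of every weight
+        # matrix, gathered by flattening the nested parameter tree.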
+        l2_reg = sum(
+            mx.sum(p**2) for name, p in tree_flatten(parameters) if "weight" in name
+        )
+        return loss + weight_decay * l2_reg
+    return loss
+
+
+def eval_fn(x, y):
+    return mx.mean(mx.argmax(x, axis=1) == y)
+
+
+def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
+    y_hat = gcn(x, adj)
+    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
+    return loss, y_hat
+
+
+def main(args):
+
+    # Data loading
+    download_cora()
+
+    x, y, adj = load_data(args)
+    train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
+
+    gcn = GCN(
+        x_dim=x.shape[-1],
+        h_dim=args.hidden_dim,
+        out_dim=args.nb_classes,
+        nb_layers=args.nb_layers,
+        dropout=args.dropout,
+        bias=args.bias,
+    )
+    mx.eval(gcn.parameters())
+
+    optimizer = optim.Adam(learning_rate=args.lr)
+    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
+
+    best_val_loss = float("inf")
+    cnt = 0
+
+    # Training loop
+    for epoch in range(args.epochs):
+
+        # Loss
+        (loss, y_hat), grads = loss_and_grad_fn(
+            gcn, x, adj, y, train_mask, args.weight_decay
+        )
+        optimizer.update(gcn, grads)
+        mx.eval(gcn.parameters(), optimizer.state)
+
+        # Validation
+        val_loss = loss_fn(y_hat[val_mask], y[val_mask])
+        val_acc = eval_fn(y_hat[val_mask], y[val_mask])
+
+        # Early stopping
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            cnt = 0
+        else:
+            cnt += 1
+            if cnt == args.patience:
+                break
+
+        print(
+            " | ".join(
+                [
+                    f"Epoch: {epoch:3d}",
+                    f"Train loss: {loss.item():.3f}",
+                    f"Val loss: {val_loss.item():.3f}",
+                    f"Val acc: {val_acc.item():.2f}",
+                ]
+            )
+        )
+
+    # Test
+    test_y_hat = gcn(x, adj)
+    test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
+    test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])
+
+    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
+
+
+if __name__ == "__main__":
+
+    parser = ArgumentParser()
+    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
+    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
+    parser.add_argument("--hidden_dim", type=int, default=20)
+    parser.add_argument("--dropout", type=float, default=0.5)
+    parser.add_argument("--nb_layers", type=int, default=2)
+    parser.add_argument("--nb_classes", type=int, default=7)
+    parser.add_argument("--bias", type=bool, default=True)
+    parser.add_argument("--lr", type=float, default=0.001)
+    parser.add_argument("--weight_decay", type=float, default=0.0)
+    parser.add_argument("--patience", type=int, default=20)
+    parser.add_argument("--epochs", type=int, default=100)
+    args = parser.parse_args()
+
+    main(args)
diff --git a/gcn/requirements.txt b/gcn/requirements.txt
new file mode 100644
index 00000000..d0671eec
--- /dev/null
+++ b/gcn/requirements.txt
@@ -0,0 +1,4 @@
+mlx==0.0.4
+numpy==1.26.2
+scipy==1.11.4
+requests==2.31.0