add GCN implementation

This commit is contained in:
Tristan Bilot 2023-12-11 17:48:07 +01:00
parent ecd96acfe4
commit ed5a830626
5 changed files with 286 additions and 0 deletions

17
gcn/README.md Normal file
View File

@ -0,0 +1,17 @@
# Graph Convolutional Network
An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
### Install requirements
First, install the few dependencies with `pip`.
```
pip install -r requirements.txt
```
### Run
To try the model, just run the `main.py` file. This will download the Cora dataset, run the training and testing.
```
python main.py
```

114
gcn/datasets.py Normal file
View File

@ -0,0 +1,114 @@
import os
import requests
import tarfile
import mlx.core as mx
import numpy as np
import scipy.sparse as sparse
"""
Preprocessing follows the same implementation as in:
https://github.com/tkipf/gcn
https://github.com/senadkurtisi/pytorch-GCN/tree/main
"""
def download_cora():
"""Downloads the cora dataset into a local cora folder."""
url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
extract_to = "."
if os.path.exists(os.path.join(extract_to, "cora")):
return
response = requests.get(url, stream=True)
if response.status_code == 200:
file_path = os.path.join(extract_to, url.split("/")[-1])
# Write the file to local disk
with open(file_path, "wb") as file:
file.write(response.raw.read())
# Extract the .tgz file
with tarfile.open(file_path, "r:gz") as tar:
tar.extractall(path=extract_to)
print(f"Cora dataset extracted to {extract_to}")
os.remove(file_path)
def train_val_test_mask(labels, num_classes):
"""Splits the loaded dataset into train/validation/test sets."""
train_set = mx.array(list(range(140)))
validation_set = mx.array(list(range(200, 500)))
test_set = mx.array(list(range(500, 1500)))
return train_set, validation_set, test_set
def enumerate_labels(labels):
"""Converts the labels from the original
string form to the integer [0:MaxLabels-1]
"""
unique = list(set(labels))
labels = np.array([unique.index(label) for label in labels])
return labels
def normalize_adjacency(adj):
"""Normalizes the adjacency matrix according to the
paper by Kipf et al.
https://arxiv.org/pdf/1609.02907.pdf
"""
adj = adj + sparse.eye(adj.shape[0])
node_degrees = np.array(adj.sum(1))
node_degrees = np.power(node_degrees, -0.5).flatten()
node_degrees[np.isinf(node_degrees)] = 0.0
node_degrees[np.isnan(node_degrees)] = 0.0
degree_matrix = sparse.diags(node_degrees, dtype=np.float32)
adj = degree_matrix @ adj @ degree_matrix
return adj
def load_data(config):
"""Loads the Cora graph data into MLX array format."""
print("Loading Cora dataset...")
# Graph nodes
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
raw_node_ids = raw_nodes_data[:, 0].astype(
"int32"
) # unique identifier of each node
raw_node_labels = raw_nodes_data[:, -1]
labels_enumerated = enumerate_labels(raw_node_labels) # target labels as integers
node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")
# Edges
ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
edges_ordered = np.array(
list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
).reshape(raw_edges_data.shape)
# Adjacency matrix
adj = sparse.coo_matrix(
(np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
dtype=np.float32,
)
# Make the adjacency matrix symmetric
adj = adj + adj.T.multiply(adj.T > adj)
adj = normalize_adjacency(adj)
# Convert to mlx array
features = mx.array(node_features.toarray(), mx.float32)
labels = mx.array(labels_enumerated, mx.int32)
adj = mx.array(adj.toarray())
print("Dataset loaded.")
return features, labels, adj

31
gcn/gcn.py Normal file
View File

@ -0,0 +1,31 @@
import mlx.nn as nn
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features, bias=True):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias)
def __call__(self, x, adj):
x = self.linear(x)
return adj @ x
class GCN(nn.Module):
def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
super(GCN, self).__init__()
layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
self.gcn_layers = [
GCNLayer(in_dim, out_dim, bias)
for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
]
self.dropout = nn.Dropout(p=dropout)
def __call__(self, x, adj):
for layer in self.gcn_layers[:-1]:
x = nn.relu(layer(x, adj))
x = self.dropout(x)
x = self.gcn_layers[-1](x, adj)
return x

120
gcn/main.py Normal file
View File

@ -0,0 +1,120 @@
from argparse import ArgumentParser
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.losses import cross_entropy
from datasets import download_cora, load_data, train_val_test_mask
from gcn import GCN
def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
l = mx.mean(nn.losses.cross_entropy(y_hat, y))
if weight_decay != 0.0:
assert parameters != None, "Model parameters missing for L2 reg."
l2_reg = mx.zeros(
1,
)
for k1, v1 in parameters.items():
for k2, v2 in v1.items():
l2_reg += mx.sum(v2["weight"] ** 2)
return l + weight_decay * l2_reg.item()
return l
def eval_fn(x, y):
return mx.mean(mx.argmax(x, axis=1) == y)
def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
y_hat = gcn(x, adj)
loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
return loss, y_hat
def main(args):
# Data loading
download_cora()
x, y, adj = load_data(args)
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
gcn = GCN(
x_dim=x.shape[-1],
h_dim=args.hidden_dim,
out_dim=args.nb_classes,
nb_layers=args.nb_layers,
dropout=args.dropout,
bias=args.bias,
)
mx.eval(gcn.parameters())
optimizer = optim.Adam(learning_rate=args.lr)
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
best_val_loss = float("inf")
cnt = 0
# Training loop
for epoch in range(args.epochs):
# Loss
(loss, y_hat), grads = loss_and_grad_fn(
gcn, x, adj, y, train_mask, args.weight_decay
)
optimizer.update(gcn, grads)
mx.eval(gcn.parameters(), optimizer.state)
# Validation
val_loss = loss_fn(y_hat[val_mask], y[val_mask])
val_acc = eval_fn(y_hat[val_mask], y[val_mask])
# Early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
cnt = 0
else:
cnt += 1
if cnt == args.patience:
break
print(
" | ".join(
[
f"Epoch: {epoch:3d}",
f"Train loss: {loss.item():.3f}",
f"Val loss: {val_loss.item():.3f}",
f"Val acc: {val_acc.item():.2f}",
]
)
)
# Test
test_y_hat = gcn(x, adj)
test_loss = loss_fn(y_hat[test_mask], y[test_mask])
test_acc = eval_fn(y_hat[test_mask], y[test_mask])
print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
parser.add_argument("--hidden_dim", type=int, default=20)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--nb_layers", type=int, default=2)
parser.add_argument("--nb_classes", type=int, default=7)
parser.add_argument("--bias", type=bool, default=True)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument("--patience", type=int, default=20)
parser.add_argument("--epochs", type=int, default=100)
args = parser.parse_args()
main(args)

4
gcn/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
mlx==0.0.4
numpy==1.26.2
scipy==1.11.4
requests==2.31.0