Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
add GCN implementation

parent ecd96acfe4
commit ed5a830626
gcn/README.md (new file, 17 lines)
@@ -0,0 +1,17 @@
# Graph Convolutional Network

An example implementation of [GCN](https://arxiv.org/pdf/1609.02907.pdf) with MLX.

### Install requirements

First, install the few dependencies with `pip`.

```
pip install -r requirements.txt
```

### Run

To try the model, just run the `main.py` file. This will download the Cora dataset, then run training and testing.

```
python main.py
```
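Hyperparameters can be overridden with the command-line flags defined in `main.py`, for example:

```
python main.py --hidden_dim 20 --dropout 0.5 --epochs 100
```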
gcn/datasets.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import os
import tarfile

import mlx.core as mx
import numpy as np
import requests
import scipy.sparse as sparse

"""
Preprocessing follows the same implementation as in:
https://github.com/tkipf/gcn
https://github.com/senadkurtisi/pytorch-GCN/tree/main
"""


def download_cora():
    """Downloads the Cora dataset into a local cora folder."""

    url = "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz"
    extract_to = "."

    if os.path.exists(os.path.join(extract_to, "cora")):
        return

    response = requests.get(url, stream=True)
    if response.status_code == 200:
        file_path = os.path.join(extract_to, url.split("/")[-1])

        # Write the file to local disk
        with open(file_path, "wb") as file:
            file.write(response.raw.read())

        # Extract the .tgz file
        with tarfile.open(file_path, "r:gz") as tar:
            tar.extractall(path=extract_to)
            print(f"Cora dataset extracted to {extract_to}")

        os.remove(file_path)


def train_val_test_mask(labels, num_classes):
    """Splits the loaded dataset into train/validation/test sets using the
    standard Cora split: nodes 0-139 for training, 200-499 for validation,
    and 500-1499 for testing."""

    train_set = mx.array(list(range(140)))
    validation_set = mx.array(list(range(200, 500)))
    test_set = mx.array(list(range(500, 1500)))

    return train_set, validation_set, test_set


def enumerate_labels(labels):
    """Converts the labels from their original string form to integers
    in [0, num_labels - 1]."""
    unique = list(set(labels))
    labels = np.array([unique.index(label) for label in labels])
    return labels


def normalize_adjacency(adj):
    """Normalizes the adjacency matrix to D^-1/2 (A + I) D^-1/2,
    following the paper by Kipf et al.
    https://arxiv.org/pdf/1609.02907.pdf
    """
    # Add self-loops so each node also aggregates its own features
    adj = adj + sparse.eye(adj.shape[0])

    node_degrees = np.array(adj.sum(1))
    node_degrees = np.power(node_degrees, -0.5).flatten()
    node_degrees[np.isinf(node_degrees)] = 0.0
    node_degrees[np.isnan(node_degrees)] = 0.0
    degree_matrix = sparse.diags(node_degrees, dtype=np.float32)

    adj = degree_matrix @ adj @ degree_matrix
    return adj


def load_data(config):
    """Loads the Cora graph data into MLX array format."""
    print("Loading Cora dataset...")

    # Graph nodes
    raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
    raw_node_ids = raw_nodes_data[:, 0].astype(
        "int32"
    )  # unique identifier of each node
    raw_node_labels = raw_nodes_data[:, -1]
    labels_enumerated = enumerate_labels(raw_node_labels)  # target labels as integers
    node_features = sparse.csr_matrix(raw_nodes_data[:, 1:-1], dtype="float32")

    # Edges
    ids_ordered = {raw_id: order for order, raw_id in enumerate(raw_node_ids)}
    raw_edges_data = np.genfromtxt(config.edges_path, dtype="int32")
    edges_ordered = np.array(
        list(map(ids_ordered.get, raw_edges_data.flatten())), dtype="int32"
    ).reshape(raw_edges_data.shape)

    # Adjacency matrix
    adj = sparse.coo_matrix(
        (np.ones(edges_ordered.shape[0]), (edges_ordered[:, 0], edges_ordered[:, 1])),
        shape=(labels_enumerated.shape[0], labels_enumerated.shape[0]),
        dtype=np.float32,
    )

    # Make the adjacency matrix symmetric
    adj = adj + adj.T.multiply(adj.T > adj)
    adj = normalize_adjacency(adj)

    # Convert to mlx array
    features = mx.array(node_features.toarray(), mx.float32)
    labels = mx.array(labels_enumerated, mx.int32)
    adj = mx.array(adj.toarray())

    print("Dataset loaded.")
    return features, labels, adj
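As a quick illustration of what `normalize_adjacency` computes, here is a hypothetical sanity check (not part of this commit) on a toy three-node graph; entry (i, j) of the result is (A + I)[i, j] / sqrt(d_i * d_j), with degrees taken after adding self-loops:

```
# Hypothetical sanity check for normalize_adjacency (not part of this commit).
import numpy as np
import scipy.sparse as sparse

from datasets import normalize_adjacency

# Toy undirected path graph 0-1-2 as a symmetric sparse matrix.
adj = sparse.coo_matrix(
    (np.ones(4), ([0, 1, 1, 2], [1, 0, 2, 1])), shape=(3, 3), dtype=np.float32
)

norm = normalize_adjacency(adj)
print(norm.toarray())
# Degrees with self-loops are [2, 3, 2], so e.g. entry (0, 1) is 1 / sqrt(2 * 3).
```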
gcn/gcn.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import mlx.nn as nn


class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def __call__(self, x, adj):
        # Project node features, then aggregate over neighbors
        x = self.linear(x)
        return adj @ x


class GCN(nn.Module):
    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super().__init__()

        # Input and output dimensions are appended to the hidden sizes,
        # so nb_layers hidden layers yield nb_layers + 1 GCNLayers in total.
        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
        self.gcn_layers = [
            GCNLayer(in_dim, out_dim, bias)
            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        # ReLU and dropout on all but the final (output) layer
        for layer in self.gcn_layers[:-1]:
            x = nn.relu(layer(x, adj))
            x = self.dropout(x)

        x = self.gcn_layers[-1](x, adj)
        return x
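A minimal usage sketch (again hypothetical, not part of this commit, using Cora's 2708 nodes, 1433 features, and 7 classes): with `nb_layers=2`, `layer_sizes` produces three `GCNLayer`s mapping x_dim to h_dim, h_dim to h_dim, and h_dim to out_dim:

```
# Hypothetical shape check (not part of this commit).
import mlx.core as mx

from gcn import GCN

num_nodes, x_dim = 2708, 1433  # Cora dimensions
x = mx.random.normal((num_nodes, x_dim))
adj = mx.eye(num_nodes)  # identity as a stand-in for the normalized adjacency

model = GCN(x_dim=x_dim, h_dim=20, out_dim=7, nb_layers=2, dropout=0.5)
y_hat = model(x, adj)
print(y_hat.shape)  # (2708, 7): one logit vector per node
```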
gcn/main.py (new file, 120 lines)
@@ -0,0 +1,120 @@
from argparse import ArgumentParser

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

from datasets import download_cora, load_data, train_val_test_mask
from gcn import GCN


def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    l = mx.mean(nn.losses.cross_entropy(y_hat, y))

    if weight_decay != 0.0:
        assert parameters is not None, "Model parameters missing for L2 reg."

        # tree_flatten walks the nested parameter tree (including the list of
        # GCN layers); penalize only the weight matrices, and keep the result
        # an mx.array (no .item()) so the penalty stays differentiable.
        l2_reg = sum(
            mx.sum(param**2)
            for name, param in tree_flatten(parameters)
            if name.endswith("weight")
        )
        return l + weight_decay * l2_reg

    return l


def eval_fn(x, y):
    return mx.mean(mx.argmax(x, axis=1) == y)


def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    y_hat = gcn(x, adj)
    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
    return loss, y_hat


def main(args):
    # Data loading
    download_cora()

    x, y, adj = load_data(args)
    train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)

    gcn = GCN(
        x_dim=x.shape[-1],
        h_dim=args.hidden_dim,
        out_dim=args.nb_classes,
        nb_layers=args.nb_layers,
        dropout=args.dropout,
        bias=args.bias,
    )
    mx.eval(gcn.parameters())

    optimizer = optim.Adam(learning_rate=args.lr)
    loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

    best_val_loss = float("inf")
    cnt = 0

    # Training loop
    for epoch in range(args.epochs):

        # Loss
        (loss, y_hat), grads = loss_and_grad_fn(
            gcn, x, adj, y, train_mask, args.weight_decay
        )
        optimizer.update(gcn, grads)
        mx.eval(gcn.parameters(), optimizer.state)

        # Validation
        val_loss = loss_fn(y_hat[val_mask], y[val_mask])
        val_acc = eval_fn(y_hat[val_mask], y[val_mask])

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            cnt = 0
        else:
            cnt += 1
            if cnt == args.patience:
                break

        print(
            " | ".join(
                [
                    f"Epoch: {epoch:3d}",
                    f"Train loss: {loss.item():.3f}",
                    f"Val loss: {val_loss.item():.3f}",
                    f"Val acc: {val_acc.item():.2f}",
                ]
            )
        )

    # Test
    test_y_hat = gcn(x, adj)
    test_loss = loss_fn(test_y_hat[test_mask], y[test_mask])
    test_acc = eval_fn(test_y_hat[test_mask], y[test_mask])

    print(f"Test loss: {test_loss.item():.3f} | Test acc: {test_acc.item():.2f}")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--nodes_path", type=str, default="cora/cora.content")
    parser.add_argument("--edges_path", type=str, default="cora/cora.cites")
    parser.add_argument("--hidden_dim", type=int, default=20)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--nb_layers", type=int, default=2)
    parser.add_argument("--nb_classes", type=int, default=7)
    # note: with type=bool, any non-empty string (even "False") parses as True
    parser.add_argument("--bias", type=bool, default=True)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args()

    main(args)
gcn/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
mlx==0.0.4
numpy==1.26.2
scipy==1.11.4
requests==2.31.0