fix comments before merge

This commit is contained in:
Tristan Bilot 2023-12-11 23:10:46 +01:00
parent b95e48e146
commit b606bfa6a7
5 changed files with 19 additions and 22 deletions

1
gcn/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
cora/

View File

@ -1,6 +1,6 @@
# Graph Convolutional Network
An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX.
### Install requirements
First, install the few dependencies with `pip`.

View File

@ -38,12 +38,12 @@ def download_cora():
os.remove(file_path)
def train_val_test_mask(labels, num_classes):
def train_val_test_mask():
"""Splits the loaded dataset into train/validation/test sets."""
train_set = mx.array(list(range(140)))
validation_set = mx.array(list(range(200, 500)))
test_set = mx.array(list(range(500, 1500)))
train_set = mx.arange(140)
validation_set = mx.arange(200, 500)
test_set = mx.arange(500, 1500)
return train_set, validation_set, test_set
@ -52,15 +52,15 @@ def enumerate_labels(labels):
"""Converts the labels from the original
string form to the integer [0:MaxLabels-1]
"""
unique = list(set(labels))
labels = np.array([unique.index(label) for label in labels])
label_map = {v: e for e, v in enumerate(set(labels))}
labels = np.array([label_map[label] for label in labels])
return labels
def normalize_adjacency(adj):
"""Normalizes the adjacency matrix according to the
paper by Kipf et al.
https://arxiv.org/pdf/1609.02907.pdf
https://arxiv.org/abs/1609.02907
"""
adj = adj + sparse.eye(adj.shape[0])
@ -78,6 +78,9 @@ def load_data(config):
"""Loads the Cora graph data into MLX array format."""
print("Loading Cora dataset...")
# Download dataset files
download_cora()
# Graph nodes
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
raw_node_ids = raw_nodes_data[:, 0].astype(

View File

@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
if weight_decay != 0.0:
assert parameters != None, "Model parameters missing for L2 reg."
l2_reg = mx.zeros(
1,
)
for leaf in tree_flatten(parameters):
l2_reg += mx.sum(leaf[1] ** 2)
l2_reg = mx.sqrt(l2_reg)
return l + weight_decay * l2_reg.item()
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
return l + weight_decay * l2_reg
return l
@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
def main(args):
# Data loading
download_cora()
x, y, adj = load_data(args)
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
train_mask, val_mask, test_mask = train_val_test_mask()
gcn = GCN(
x_dim=x.shape[-1],

View File

@ -1,4 +1,4 @@
mlx==0.0.4
numpy==1.26.2
scipy==1.11.4
requests==2.31.0
mlx>=0.0.4
numpy>=1.26.2
scipy>=1.11.4
requests>=2.31.0