fix comments before merge

This commit is contained in:
Tristan Bilot
2023-12-11 23:10:46 +01:00
parent b95e48e146
commit b606bfa6a7
5 changed files with 19 additions and 22 deletions

View File

@@ -38,12 +38,12 @@ def download_cora():
os.remove(file_path)
def train_val_test_mask(labels, num_classes):
def train_val_test_mask():
"""Splits the loaded dataset into train/validation/test sets."""
train_set = mx.array(list(range(140)))
validation_set = mx.array(list(range(200, 500)))
test_set = mx.array(list(range(500, 1500)))
train_set = mx.arange(140)
validation_set = mx.arange(200, 500)
test_set = mx.arange(500, 1500)
return train_set, validation_set, test_set
@@ -52,15 +52,15 @@ def enumerate_labels(labels):
"""Converts the labels from the original
string form to the integer [0:MaxLabels-1]
"""
unique = list(set(labels))
labels = np.array([unique.index(label) for label in labels])
label_map = {v: e for e, v in enumerate(set(labels))}
labels = np.array([label_map[label] for label in labels])
return labels
def normalize_adjacency(adj):
"""Normalizes the adjacency matrix according to the
paper by Kipf et al.
https://arxiv.org/pdf/1609.02907.pdf
https://arxiv.org/abs/1609.02907
"""
adj = adj + sparse.eye(adj.shape[0])
@@ -78,6 +78,9 @@ def load_data(config):
"""Loads the Cora graph data into MLX array format."""
print("Loading Cora dataset...")
# Download dataset files
download_cora()
# Graph nodes
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
raw_node_ids = raw_nodes_data[:, 0].astype(