mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
fix comments before merge
This commit is contained in:
@@ -38,12 +38,12 @@ def download_cora():
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
def train_val_test_mask(labels, num_classes):
|
||||
def train_val_test_mask():
|
||||
"""Splits the loaded dataset into train/validation/test sets."""
|
||||
|
||||
train_set = mx.array(list(range(140)))
|
||||
validation_set = mx.array(list(range(200, 500)))
|
||||
test_set = mx.array(list(range(500, 1500)))
|
||||
train_set = mx.arange(140)
|
||||
validation_set = mx.arange(200, 500)
|
||||
test_set = mx.arange(500, 1500)
|
||||
|
||||
return train_set, validation_set, test_set
|
||||
|
||||
@@ -52,15 +52,15 @@ def enumerate_labels(labels):
|
||||
"""Converts the labels from the original
|
||||
string form to the integer [0:MaxLabels-1]
|
||||
"""
|
||||
unique = list(set(labels))
|
||||
labels = np.array([unique.index(label) for label in labels])
|
||||
label_map = {v: e for e, v in enumerate(set(labels))}
|
||||
labels = np.array([label_map[label] for label in labels])
|
||||
return labels
|
||||
|
||||
|
||||
def normalize_adjacency(adj):
|
||||
"""Normalizes the adjacency matrix according to the
|
||||
paper by Kipf et al.
|
||||
https://arxiv.org/pdf/1609.02907.pdf
|
||||
https://arxiv.org/abs/1609.02907
|
||||
"""
|
||||
adj = adj + sparse.eye(adj.shape[0])
|
||||
|
||||
@@ -78,6 +78,9 @@ def load_data(config):
|
||||
"""Loads the Cora graph data into MLX array format."""
|
||||
print("Loading Cora dataset...")
|
||||
|
||||
# Download dataset files
|
||||
download_cora()
|
||||
|
||||
# Graph nodes
|
||||
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
|
||||
raw_node_ids = raw_nodes_data[:, 0].astype(
|
||||
|
Reference in New Issue
Block a user