mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
fix comments before merge
This commit is contained in:
parent
b95e48e146
commit
b606bfa6a7
1
gcn/.gitignore
vendored
Normal file
1
gcn/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
cora/
|
@ -1,6 +1,6 @@
|
||||
# Graph Convolutional Network
|
||||
|
||||
An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
|
||||
An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX.
|
||||
|
||||
### Install requirements
|
||||
First, install the few dependencies with `pip`.
|
||||
|
@ -38,12 +38,12 @@ def download_cora():
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
def train_val_test_mask(labels, num_classes):
|
||||
def train_val_test_mask():
|
||||
"""Splits the loaded dataset into train/validation/test sets."""
|
||||
|
||||
train_set = mx.array(list(range(140)))
|
||||
validation_set = mx.array(list(range(200, 500)))
|
||||
test_set = mx.array(list(range(500, 1500)))
|
||||
train_set = mx.arange(140)
|
||||
validation_set = mx.arange(200, 500)
|
||||
test_set = mx.arange(500, 1500)
|
||||
|
||||
return train_set, validation_set, test_set
|
||||
|
||||
@ -52,15 +52,15 @@ def enumerate_labels(labels):
|
||||
"""Converts the labels from the original
|
||||
string form to the integer [0:MaxLabels-1]
|
||||
"""
|
||||
unique = list(set(labels))
|
||||
labels = np.array([unique.index(label) for label in labels])
|
||||
label_map = {v: e for e, v in enumerate(set(labels))}
|
||||
labels = np.array([label_map[label] for label in labels])
|
||||
return labels
|
||||
|
||||
|
||||
def normalize_adjacency(adj):
|
||||
"""Normalizes the adjacency matrix according to the
|
||||
paper by Kipf et al.
|
||||
https://arxiv.org/pdf/1609.02907.pdf
|
||||
https://arxiv.org/abs/1609.02907
|
||||
"""
|
||||
adj = adj + sparse.eye(adj.shape[0])
|
||||
|
||||
@ -78,6 +78,9 @@ def load_data(config):
|
||||
"""Loads the Cora graph data into MLX array format."""
|
||||
print("Loading Cora dataset...")
|
||||
|
||||
# Download dataset files
|
||||
download_cora()
|
||||
|
||||
# Graph nodes
|
||||
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
|
||||
raw_node_ids = raw_nodes_data[:, 0].astype(
|
||||
|
13
gcn/main.py
13
gcn/main.py
@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
|
||||
if weight_decay != 0.0:
|
||||
assert parameters != None, "Model parameters missing for L2 reg."
|
||||
|
||||
l2_reg = mx.zeros(
|
||||
1,
|
||||
)
|
||||
for leaf in tree_flatten(parameters):
|
||||
l2_reg += mx.sum(leaf[1] ** 2)
|
||||
l2_reg = mx.sqrt(l2_reg)
|
||||
return l + weight_decay * l2_reg.item()
|
||||
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
|
||||
return l + weight_decay * l2_reg
|
||||
return l
|
||||
|
||||
|
||||
@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
|
||||
def main(args):
|
||||
|
||||
# Data loading
|
||||
download_cora()
|
||||
|
||||
x, y, adj = load_data(args)
|
||||
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
|
||||
train_mask, val_mask, test_mask = train_val_test_mask()
|
||||
|
||||
gcn = GCN(
|
||||
x_dim=x.shape[-1],
|
||||
|
@ -1,4 +1,4 @@
|
||||
mlx==0.0.4
|
||||
numpy==1.26.2
|
||||
scipy==1.11.4
|
||||
requests==2.31.0
|
||||
mlx>=0.0.4
|
||||
numpy>=1.26.2
|
||||
scipy>=1.11.4
|
||||
requests>=2.31.0
|
||||
|
Loading…
Reference in New Issue
Block a user