mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
fix comments before merge
This commit is contained in:
parent
b95e48e146
commit
b606bfa6a7
1
gcn/.gitignore
vendored
Normal file
1
gcn/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
cora/
|
@ -1,6 +1,6 @@
|
|||||||
# Graph Convolutional Network
|
# Graph Convolutional Network
|
||||||
|
|
||||||
An example of [GCN](https://arxiv.org/pdf/1609.02907.pdf%EF%BC%89) implementation with MLX.
|
An example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX.
|
||||||
|
|
||||||
### Install requirements
|
### Install requirements
|
||||||
First, install the few dependencies with `pip`.
|
First, install the few dependencies with `pip`.
|
||||||
|
@ -38,12 +38,12 @@ def download_cora():
|
|||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
|
|
||||||
|
|
||||||
def train_val_test_mask(labels, num_classes):
|
def train_val_test_mask():
|
||||||
"""Splits the loaded dataset into train/validation/test sets."""
|
"""Splits the loaded dataset into train/validation/test sets."""
|
||||||
|
|
||||||
train_set = mx.array(list(range(140)))
|
train_set = mx.arange(140)
|
||||||
validation_set = mx.array(list(range(200, 500)))
|
validation_set = mx.arange(200, 500)
|
||||||
test_set = mx.array(list(range(500, 1500)))
|
test_set = mx.arange(500, 1500)
|
||||||
|
|
||||||
return train_set, validation_set, test_set
|
return train_set, validation_set, test_set
|
||||||
|
|
||||||
@ -52,15 +52,15 @@ def enumerate_labels(labels):
|
|||||||
"""Converts the labels from the original
|
"""Converts the labels from the original
|
||||||
string form to the integer [0:MaxLabels-1]
|
string form to the integer [0:MaxLabels-1]
|
||||||
"""
|
"""
|
||||||
unique = list(set(labels))
|
label_map = {v: e for e, v in enumerate(set(labels))}
|
||||||
labels = np.array([unique.index(label) for label in labels])
|
labels = np.array([label_map[label] for label in labels])
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def normalize_adjacency(adj):
|
def normalize_adjacency(adj):
|
||||||
"""Normalizes the adjacency matrix according to the
|
"""Normalizes the adjacency matrix according to the
|
||||||
paper by Kipf et al.
|
paper by Kipf et al.
|
||||||
https://arxiv.org/pdf/1609.02907.pdf
|
https://arxiv.org/abs/1609.02907
|
||||||
"""
|
"""
|
||||||
adj = adj + sparse.eye(adj.shape[0])
|
adj = adj + sparse.eye(adj.shape[0])
|
||||||
|
|
||||||
@ -78,6 +78,9 @@ def load_data(config):
|
|||||||
"""Loads the Cora graph data into MLX array format."""
|
"""Loads the Cora graph data into MLX array format."""
|
||||||
print("Loading Cora dataset...")
|
print("Loading Cora dataset...")
|
||||||
|
|
||||||
|
# Download dataset files
|
||||||
|
download_cora()
|
||||||
|
|
||||||
# Graph nodes
|
# Graph nodes
|
||||||
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
|
raw_nodes_data = np.genfromtxt(config.nodes_path, dtype="str")
|
||||||
raw_node_ids = raw_nodes_data[:, 0].astype(
|
raw_node_ids = raw_nodes_data[:, 0].astype(
|
||||||
|
13
gcn/main.py
13
gcn/main.py
@ -16,13 +16,8 @@ def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
|
|||||||
if weight_decay != 0.0:
|
if weight_decay != 0.0:
|
||||||
assert parameters != None, "Model parameters missing for L2 reg."
|
assert parameters != None, "Model parameters missing for L2 reg."
|
||||||
|
|
||||||
l2_reg = mx.zeros(
|
l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
|
||||||
1,
|
return l + weight_decay * l2_reg
|
||||||
)
|
|
||||||
for leaf in tree_flatten(parameters):
|
|
||||||
l2_reg += mx.sum(leaf[1] ** 2)
|
|
||||||
l2_reg = mx.sqrt(l2_reg)
|
|
||||||
return l + weight_decay * l2_reg.item()
|
|
||||||
return l
|
return l
|
||||||
|
|
||||||
|
|
||||||
@ -39,10 +34,8 @@ def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
|
|||||||
def main(args):
|
def main(args):
|
||||||
|
|
||||||
# Data loading
|
# Data loading
|
||||||
download_cora()
|
|
||||||
|
|
||||||
x, y, adj = load_data(args)
|
x, y, adj = load_data(args)
|
||||||
train_mask, val_mask, test_mask = train_val_test_mask(y, args.nb_classes)
|
train_mask, val_mask, test_mask = train_val_test_mask()
|
||||||
|
|
||||||
gcn = GCN(
|
gcn = GCN(
|
||||||
x_dim=x.shape[-1],
|
x_dim=x.shape[-1],
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx==0.0.4
|
mlx>=0.0.4
|
||||||
numpy==1.26.2
|
numpy>=1.26.2
|
||||||
scipy==1.11.4
|
scipy>=1.11.4
|
||||||
requests==2.31.0
|
requests>=2.31.0
|
||||||
|
Loading…
Reference in New Issue
Block a user