mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
add GCN implementation
This commit is contained in:
31
gcn/gcn.py
Normal file
31
gcn/gcn.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class GCNLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
super(GCNLayer, self).__init__()
|
||||
self.linear = nn.Linear(in_features, out_features, bias)
|
||||
|
||||
def __call__(self, x, adj):
|
||||
x = self.linear(x)
|
||||
return adj @ x
|
||||
|
||||
|
||||
class GCN(nn.Module):
|
||||
def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
|
||||
super(GCN, self).__init__()
|
||||
|
||||
layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
|
||||
self.gcn_layers = [
|
||||
GCNLayer(in_dim, out_dim, bias)
|
||||
for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
|
||||
]
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def __call__(self, x, adj):
|
||||
for layer in self.gcn_layers[:-1]:
|
||||
x = nn.relu(layer(x, adj))
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.gcn_layers[-1](x, adj)
|
||||
return x
|
Reference in New Issue
Block a user