Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00).
32 lines · 915 B · Python
import mlx.nn as nn
|
|
|
|
|
|
class GCNLayer(nn.Module):
    """One graph-convolution step: project node features, then propagate them.

    Computes ``adj @ linear(x)``, where ``linear`` is an ``nn.Linear`` mapping
    ``in_features`` to ``out_features``.

    Args:
        in_features: Width of each incoming node feature vector.
        out_features: Width of each outgoing node feature vector.
        bias: Whether the linear projection learns an additive bias.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def __call__(self, x, adj):
        """Apply the layer.

        Args:
            x: Node feature matrix. Presumably shaped (num_nodes, in_features);
                TODO confirm with callers.
            adj: Matrix left-multiplied onto the projected features.
                Presumably a (normalized) (num_nodes, num_nodes) adjacency;
                TODO confirm with callers.

        Returns:
            ``adj @ self.linear(x)``.
        """
        # Project first so the propagation works in the (usually smaller)
        # output feature space.
        projected = self.linear(x)
        return adj @ projected
|
|
|
|
|
|
class GCN(nn.Module):
    """A stack of ``GCNLayer``s with ReLU and dropout between layers.

    Feature widths run ``x_dim -> h_dim (x nb_layers) -> out_dim``, so the
    model holds ``nb_layers + 1`` ``GCNLayer``s. For example, the default
    ``nb_layers=2`` yields three layers, and ``nb_layers=0`` yields a single
    ``x_dim -> out_dim`` layer. ``nb_layers`` therefore counts hidden
    representations, not total layers.

    Args:
        x_dim: Input node feature width.
        h_dim: Hidden feature width.
        out_dim: Output feature width (e.g. number of classes).
        nb_layers: Number of ``h_dim``-wide hidden representations.
        dropout: Drop probability passed to ``nn.Dropout``.
        bias: Whether each layer's linear projection has a bias.
    """

    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super().__init__()

        # Width at every stage; consecutive pairs define each layer's shape.
        dims = [x_dim, *([h_dim] * nb_layers), out_dim]
        # zip stops at the shorter input, so this pairs dims[i] with dims[i + 1].
        self.gcn_layers = [
            GCNLayer(d_in, d_out, bias) for d_in, d_out in zip(dims, dims[1:])
        ]
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        """Run ``x`` through every layer, propagating over ``adj`` each time.

        ReLU and then dropout follow every layer except the last. The final
        layer's raw output (no activation) is returned.
        """
        hidden_layers, output_layer = self.gcn_layers[:-1], self.gcn_layers[-1]
        for layer in hidden_layers:
            activated = nn.relu(layer(x, adj))
            x = self.dropout(activated)
        return output_layer(x, adj)
|