# mlx-examples/gcn/gcn.py
# 2023-12-11 17:48:07 +01:00
#
# 32 lines
# 915 B
# Python

import mlx.nn as nn
class GCNLayer(nn.Module):
    """Single graph-convolution step: a linear map followed by ``adj @ x``.

    The layer owns one ``nn.Linear`` and has no activation of its own; the
    caller decides what non-linearity (if any) to apply afterwards.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Create the underlying linear projection.

        Args:
            in_features: Input size handed to ``nn.Linear``.
            out_features: Output size handed to ``nn.Linear``.
            bias: Forwarded to ``nn.Linear`` as its bias flag.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def __call__(self, x, adj):
        """Project ``x`` with the linear layer, then left-multiply by ``adj``.

        Args:
            x: Input features.
                NOTE(review): presumably (num_nodes, in_features) - confirm.
            adj: Matrix applied on the left of the projected features.
                NOTE(review): presumably a (num_nodes, num_nodes) adjacency
                matrix; any normalisation is up to the caller - confirm.

        Returns:
            The result of ``adj @ self.linear(x)``.
        """
        return adj @ self.linear(x)
class GCN(nn.Module):
    """Stack of ``GCNLayer`` modules with ReLU and dropout between them.

    The feature sizes run ``x_dim -> h_dim (repeated nb_layers times) ->
    out_dim``, so ``nb_layers + 1`` ``GCNLayer`` instances are created for a
    non-negative ``nb_layers``.
    """

    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        """Build the layer stack and the shared dropout module.

        Args:
            x_dim: Size of the input features.
            h_dim: Size of every hidden representation.
            out_dim: Size of the final output.
            nb_layers: How many times ``h_dim`` is repeated in the size chain.
            dropout: Probability passed to ``nn.Dropout`` as ``p``.
            bias: Forwarded to every ``GCNLayer``.
        """
        super().__init__()

        # Feature widths at every boundary of the stack, input to output.
        widths = [x_dim, *([h_dim] * nb_layers), out_dim]

        # One layer per consecutive pair of widths; zip stops at the shorter
        # sequence, so the last width is only ever used as a target size.
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(GCNLayer(fan_in, fan_out, bias))

        self.gcn_layers = stack
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        """Run ``x`` through every layer, using the same ``adj`` each time.

        Every layer except the last is followed by ReLU and then dropout; the
        last layer's output is returned as-is.

        Args:
            x: Input features passed to the first layer.
            adj: Matrix passed unchanged to every ``GCNLayer``.

        Returns:
            Output of the final ``GCNLayer``.
        """
        hidden_layers = self.gcn_layers[:-1]
        for conv in hidden_layers:
            activated = nn.relu(conv(x, adj))
            x = self.dropout(activated)

        output_layer = self.gcn_layers[-1]
        return output_layer(x, adj)