Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-31 19:18:09 +08:00)
			
		
		
		
	
		
			
				
	
	
		
			32 lines
		
	
	
		
			915 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			32 lines
		
	
	
		
			915 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| import mlx.nn as nn
 | |
| 
 | |
| 
 | |
class GCNLayer(nn.Module):
    """One graph-convolution step: a linear projection, then aggregation via ``adj``.

    Args:
        in_features: Size of each incoming node feature vector.
        out_features: Size of each outgoing node feature vector.
        bias: Whether the underlying ``nn.Linear`` learns an additive bias.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def __call__(self, x, adj):
        """Return ``adj @ linear(x)``.

        NOTE(review): presumably ``x`` is (num_nodes, in_features) and ``adj``
        is a (num_nodes, num_nodes), possibly normalised, adjacency matrix —
        confirm against callers.
        """
        return adj @ self.linear(x)
 | |
| 
 | |
| 
 | |
class GCN(nn.Module):
    """Graph convolutional network built from a stack of ``GCNLayer``s.

    Layer widths run ``x_dim -> h_dim (repeated nb_layers times) -> out_dim``,
    so the network holds ``nb_layers + 1`` ``GCNLayer``s in total. When
    ``nb_layers`` is 0 (or negative), this is a single ``x_dim -> out_dim``
    layer.

    Args:
        x_dim: Input node-feature size.
        h_dim: Width of every hidden layer.
        out_dim: Output feature size of the final layer.
        nb_layers: Number of hidden layers of width ``h_dim``.
        dropout: Drop probability handed to ``nn.Dropout``.
        bias: Whether each layer's linear projection uses a bias.
    """

    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super().__init__()

        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
        # Pair consecutive widths: (x_dim, h_dim), (h_dim, h_dim), ..., (h_dim, out_dim).
        # The loop names deliberately differ from the ``out_dim`` parameter so the
        # comprehension does not shadow it.
        self.gcn_layers = [
            GCNLayer(d_in, d_out, bias)
            for d_in, d_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        # NOTE(review): per the MLX docs, nn.Dropout only drops in training mode
        # (Module.train()/eval()); confirm callers toggle this as intended.
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        """Propagate node features ``x`` through every layer using ``adj``.

        Each layer but the last is followed by ReLU and then dropout. The last
        layer's output is returned with no activation, so it is presumably
        logits for a downstream loss — verify against callers.
        """
        for layer in self.gcn_layers[:-1]:
            x = self.dropout(nn.relu(layer(x, adj)))
        return self.gcn_layers[-1](x, adj)
 | 
