mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04)
53 lines · 1.3 KiB · Python

# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)

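# Note: nn.Module also discovers parameters held in plain Python containers,
# so the Linear layers kept in the `layers` list above are tracked automatically.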
if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Load the model
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # Note, the model parameters are saved in the export function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

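    # export_function traces `forward` on `example_x` and writes the resulting
    # graph to eval_mlp.mlxfn; arrays captured from the enclosing scope (here,
    # the model parameters) are stored in the file alongside it.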
    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)
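
Usage note: once eval_mlp.mlxfn has been written, a separate Python process can
run the exported forward pass without defining the MLP class at all, since the
parameters travel inside the file. The sketch below is illustrative only; the
script name run_exported.py and the assumption that eval_mlp.mlxfn sits in the
working directory are not part of the example above.

# run_exported.py (hypothetical): reload and run the exported function.
import mlx.core as mx

imported_forward = mx.import_function("eval_mlp.mlxfn")

# The input is expected to match the example input used at export time,
# i.e. shape (8, 32) and float32 dtype here.
x = mx.random.uniform(shape=(8, 32))
(logits,) = imported_forward(x)
print(logits.shape)  # (8, 10)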