# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class MLP(nn.Module):
    """A simple multi-layer perceptron.

    The network has ``num_layers`` hidden ``nn.Linear`` layers of width
    ``hidden_dim``, each followed by a ReLU, and a final ``nn.Linear`` output
    layer (``hidden_dim -> output_dim``) with no activation.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        # Pair consecutive sizes to get (in_features, out_features) per layer.
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """Apply every layer in order, with a ReLU after all but the last."""
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        return self.layers[-1](x)


def main() -> None:
    """Export an MLP forward pass, re-import it, and check the outputs match."""
    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Build the model with a fixed seed so its parameters are reproducible.
    mx.random.seed(0)  # Seed for params
    model = MLP(
        num_layers=5,
        input_dim=input_dim,
        hidden_dim=64,
        output_dim=output_dim,
    )
    # Force evaluation of the parameters before the function is exported.
    mx.eval(model.parameters())

    # Note, the model parameters are saved in the export function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")

    expected = forward(example_x)
    # The imported function returns its outputs as a sequence; unpack the one.
    (out,) = imported_forward(example_x)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not mx.allclose(expected, out).item():
        raise RuntimeError("imported function output does not match the original")
    print(out)


if __name__ == "__main__":
    main()