mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
53 lines
1.3 KiB
Python
![]() |
# Copyright © 2024 Apple Inc.
|
||
|
|
||
|
import mlx.core as mx
|
||
|
import mlx.nn as nn
|
||
|
import mlx.utils
|
||
|
|
||
|
|
||
|
class MLP(nn.Module):
    """A simple MLP.

    A stack of ``nn.Linear`` layers with sizes
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``
    (``num_layers`` hidden layers). ReLU is applied after every layer except
    the last, so the final layer's output is returned without an activation.

    Args:
        num_layers: Number of hidden layers. ``0`` yields a single linear
            layer mapping ``input_dim`` directly to ``output_dim``.
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        if num_layers < 0:
            # ``[hidden_dim] * num_layers`` is silently empty for negative
            # values, which would quietly build a one-layer model instead.
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        # One Linear layer per consecutive pair of sizes.
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """Run ``x`` through all layers.

        Every layer except the last is followed by ReLU.
        """
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        return self.layers[-1](x)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Problem sizes for the example.
    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Build the model under a fixed seed so its parameters are reproducible,
    # then force evaluation of its parameters (MLX computes lazily).
    mx.random.seed(0)
    model = MLP(
        num_layers=5,
        input_dim=input_dim,
        hidden_dim=64,
        output_dim=output_dim,
    )
    mx.eval(model)

    def forward(inputs):
        # The model parameters are saved in the exported function, so the
        # .mlxfn file does not need the Python model to run.
        return model(inputs)

    # Use a separate seed for the example input used to trace the export.
    mx.random.seed(42)
    sample = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, sample)

    # Round-trip check: import the exported function back into Python and
    # compare it with the eager result.
    loaded = mx.import_function("eval_mlp.mlxfn")
    reference = forward(sample)
    [result] = loaded(sample)
    assert mx.allclose(reference, result)
    print(result)
|