# Copyright © 2024 Apple Inc.
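
# This example exports MLX functions to files and imports them back: the
# parameter initialization and one SGD training step for a small MLP are
# saved as `.mlxfn` files, and the training loop then runs through the
# imported step.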

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":
|
|
|
|
batch_size = 8
|
|
input_dim = 32
|
|
output_dim = 10
|
|
|
|
    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
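        # Flatten the nested model and optimizer state into a flat list of
        # arrays that can cross the export boundary; keep the tree structure
        # on the side to rebuild the nested state later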
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
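    # Exported functions take and return arrays only, so export a closure
    # that returns just the flat state (the last element of init's result)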
mx.export_function("init_mlp.mlxfn", lambda: init()[-1])
|
|
|
|
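    # The loss as a function of the parameters; `model.update` installs the
    # traced parameters in the module before the forward pass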
    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

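    # A single training step on the flattened state: unflatten the inputs,
    # compute the loss and gradients, apply the update, and re-flatten the
    # outputs so the function stays exportable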
    def step(*inputs):
        *state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)
|
|
|
|
# Export one step of SGD
|
|
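    # Import the training step back. The saved graph was traced with the
    # example inputs above, so it is specialized to their shapes and dtypes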
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
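
    # A minimal usage sketch (an assumption, not part of the original script):
    # the exported initialization can be imported the same way to produce a
    # fresh flat state from the saved graph, e.g. to restart training
    imported_init = mx.import_function("init_mlp.mlxfn")
    fresh_state = imported_init()  # the flattened model + optimizer state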