mlx/examples/export/train_mlp.py
Awni Hannun 4ba0c24a8f
Export / import functions to / from a file (#1642)
* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplification

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
2024-12-24 11:19:13 -08:00

77 lines
2.4 KiB
Python

# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils
class MLP(nn.Module):
    """A simple MLP.

    The network has ``num_layers`` hidden layers of width ``hidden_dim``,
    each followed by a ReLU, and then a final linear projection to
    ``output_dim`` with no activation.

    Args:
        num_layers: Number of hidden (ReLU-activated) layers.
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features.
    """

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        dims = [input_dim, *([hidden_dim] * num_layers), output_dim]
        # One Linear per consecutive (in_features, out_features) pair.
        self.layers = [
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(dims, dims[1:])
        ]

    def __call__(self, x):
        *hidden, head = self.layers
        for layer in hidden:
            x = nn.relu(layer(x))
        # The output layer is left linear (no activation).
        return head(x)
if __name__ == "__main__":
    # Data / model configuration.
    batch_size = 8
    input_dim = 32
    output_dim = 10
    hidden_dim = 64
    num_layers = 3
    learning_rate = 1e-1
    # Training-loop configuration.
    num_iters = 100
    log_every = 10

    def init():
        """Build the model and optimizer and flatten their combined state.

        Returns ``(model, optimizer, tree_structure, state)``: ``tree_structure``
        holds the flattened keys of ``[model.parameters(), optimizer.state]``
        and ``state`` holds the matching arrays, in the same order.
        """
        # Seed for the parameter initialization (reseeded on every call).
        mx.random.seed(0)
        model = MLP(
            num_layers=num_layers,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
        )
        optimizer = optim.SGD(learning_rate=learning_rate)
        optimizer.init(model.parameters())
        state = [model.parameters(), optimizer.state]
        # tree_flatten yields (key, array) pairs; split them so the arrays can
        # be passed positionally to (and returned from) exported functions.
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        """Mean cross-entropy of ``model`` run with ``params`` on ``(X, y)``."""
        # Load params into the module so the forward pass uses them, which is
        # what lets value_and_grad differentiate with respect to params.
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        """Run one SGD step on flattened state.

        ``inputs`` is the flat ``[params, optimizer.state]`` arrays (in
        ``tree_structure`` order) followed by ``X`` and ``y``. Returns the
        updated flat state followed by the loss.
        """
        *flat_state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(
            list(zip(tree_structure, flat_state))
        )
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, flat_state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *flat_state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import the exported step and train with it, feeding the updated state
    # back in each iteration (on the same fixed batch).
    imported_step = mx.import_function("train_mlp.mlxfn")
    for it in range(num_iters):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % log_every == 0:
            print(f"Loss {loss.item():.6}")