Export / import functions to / from a file (#1642)
* export and import functions
* refactor + works for few primitives
* nit
* allow primitives with state
* nit
* nit
* simplify serialize / deserialize
* fix for constants
* python bindings
* maybe fix serialize failure case
* add example
* more primitives, training kind of works
* same result for python and c++
* some fixes
* fix export
* template it up
* some simplification
* rebase
* allow kwargs and multiple functions
* exporter
* more primitives for exporting
* deal with endianness
* handle invalid stream
* add docstring
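In Python, the new API round-trips a function through a `.mlxfn` file. A minimal sketch, distilled from the examples added below (the example inputs are used to trace the function at export time):

```python
import mlx.core as mx


def fun(x, y):
    return x + y


x, y = mx.array(1.0), mx.array(2.0)

# Trace `fun` with example inputs and serialize it to disk
mx.export_function("fun.mlxfn", fun, x, y)

# Load it back; imported functions return a tuple of outputs
imported_fun = mx.import_function("fun.mlxfn")
(out,) = imported_fun(x, y)
print(out)  # array(3, dtype=float32)
```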
examples/export/CMakeLists.txt (new file)
@@ -0,0 +1,27 @@
cmake_minimum_required(VERSION 3.27)

project(import_mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# ----------------------------- Dependencies -----------------------------
find_package(
  Python 3.9
  COMPONENTS Interpreter Development.Module
  REQUIRED)
# Ask pip where the mlx Python package lives and use that as MLX_ROOT so
# find_package can pick up the CMake config bundled with the package
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m pip show mlx
  COMMAND grep location
  COMMAND awk "{print $4 \"/mlx\"}"
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

add_executable(eval_mlp eval_mlp.cpp)
target_link_libraries(eval_mlp PRIVATE mlx)

add_executable(train_mlp train_mlp.cpp)
target_link_libraries(train_mlp PRIVATE mlx)
examples/export/README.md (new file)
@@ -0,0 +1,49 @@
## Setup

Install MLX:

```bash
pip install "mlx>=0.22"
```

Build the C++ examples:

```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```

## Run

### Eval MLP

Run the Python script to export the eval function:

```bash
python eval_mlp.py
```

Then run the C++ program to import and run the function:

```bash
./build/eval_mlp
```

The Python and C++ programs should output the same result.

### Train MLP

Run the Python script to export the model initialization and training functions:

```bash
python train_mlp.py
```

Then run the C++ program to import and run the functions:

```bash
./build/train_mlp
```

The Python and C++ programs should output the same results.
examples/export/eval_mlp.cpp (new file)
@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

using namespace mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;

  // Make the input
  random::seed(42);
  auto example_x = random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = import_function("eval_mlp.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

  std::cout << out << std::endl;

  return 0;
}
examples/export/eval_mlp.py (new file)
@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    # Load the model
    mx.random.seed(0)  # Seed for params
    model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim)
    mx.eval(model)

    # Note: the model parameters are saved in the exported function
    def forward(x):
        return model(x)

    mx.random.seed(42)  # Seed for input
    example_x = mx.random.uniform(shape=(batch_size, input_dim))

    mx.export_function("eval_mlp.mlxfn", forward, example_x)

    # Import in Python
    imported_forward = mx.import_function("eval_mlp.mlxfn")
    expected = forward(example_x)
    (out,) = imported_forward(example_x)
    assert mx.allclose(expected, out)
    print(out)
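As the note in the script says, the parameters are baked into `eval_mlp.mlxfn` as constants, so the file is self-contained. A minimal sketch, assuming a fresh Python process with no access to the `MLP` class or its weights:

```python
import mlx.core as mx

# No model definition or weights needed; they were captured at export time
forward = mx.import_function("eval_mlp.mlxfn")
x = mx.random.uniform(shape=(8, 32))  # matches the traced input shape
(out,) = forward(x)
print(out.shape)  # (8, 10)
```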
examples/export/train_mlp.cpp (new file)
@@ -0,0 +1,35 @@
// Copyright © 2024 Apple Inc.

#include <mlx/mlx.h>
#include <iostream>

using namespace mlx::core;

int main() {
  int batch_size = 8;
  int input_dim = 32;
  int output_dim = 10;

  // The exported init function returns the flattened model parameters
  // and optimizer state
  auto state = import_function("init_mlp.mlxfn")({});

  // Make the input
  random::seed(42);
  auto example_X = random::normal({batch_size, input_dim});
  auto example_y = random::randint(0, output_dim, {batch_size});

  // Import the function
  auto step = import_function("train_mlp.mlxfn");

  // Call the imported function: each step consumes [state..., X, y] and
  // returns [state..., loss]
  for (int it = 0; it < 100; ++it) {
    state.insert(state.end(), {example_X, example_y});
    state = step(state);
    eval(state);
    auto loss = state.back();
    state.pop_back();
    if (it % 10 == 0) {
      std::cout << "Loss " << loss.item<float>() << std::endl;
    }
  }
  return 0;
}
examples/export/train_mlp.py (new file)
@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


if __name__ == "__main__":

    batch_size = 8
    input_dim = 32
    output_dim = 10

    def init():
        # Seed for the parameter initialization
        mx.random.seed(0)
        model = MLP(
            num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim
        )
        optimizer = optim.SGD(learning_rate=1e-1)
        optimizer.init(model.parameters())
        state = [model.parameters(), optimizer.state]
        tree_structure, state = zip(*mlx.utils.tree_flatten(state))
        return model, optimizer, tree_structure, state

    # Export the model parameter initialization
    model, optimizer, tree_structure, state = init()
    mx.eval(state)
    mx.export_function("init_mlp.mlxfn", lambda: init()[-1])

    def loss_fn(params, X, y):
        model.update(params)
        return nn.losses.cross_entropy(model(X), y, reduction="mean")

    def step(*inputs):
        *state, X, y = inputs
        params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state)))
        optimizer.state = opt_state
        loss, grads = mx.value_and_grad(loss_fn)(params, X, y)
        params = optimizer.apply_gradients(grads, params)
        _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state]))
        return *state, loss

    # Make some random data
    mx.random.seed(42)
    example_X = mx.random.normal(shape=(batch_size, input_dim))
    example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,))

    # Export one step of SGD
    mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y)

    # Import in Python and run the training loop
    imported_step = mx.import_function("train_mlp.mlxfn")

    for it in range(100):
        *state, loss = imported_step(*state, example_X, example_y)
        if it % 10 == 0:
            print(f"Loss {loss.item():.6}")
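The two exported files can also drive training from a fresh Python process, mirroring `train_mlp.cpp`. A sketch, assuming `init_mlp.mlxfn` and `train_mlp.mlxfn` were produced by the script above:

```python
import mlx.core as mx

init = mx.import_function("init_mlp.mlxfn")
step = mx.import_function("train_mlp.mlxfn")

state = list(init())  # flattened parameters + optimizer state

mx.random.seed(42)
example_X = mx.random.normal(shape=(8, 32))
example_y = mx.random.randint(low=0, high=10, shape=(8,))

for it in range(100):
    *state, loss = step(*state, example_X, example_y)
    if it % 10 == 0:
        print(f"Loss {loss.item():.6}")
```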