diff --git a/CMakeLists.txt b/CMakeLists.txt index 78bae582b..662f122e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) if(NOT MLX_VERSION) set(MLX_VERSION 0.21.1) endif() +add_compile_definitions("MLX_VERSION=${MLX_VERSION}") # --------------------- Processor tests ------------------------- diff --git a/docs/src/index.rst b/docs/src/index.rst index 1e5e6ad8a..4c41800d7 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -61,6 +61,7 @@ are the CPU and GPU. python/array python/data_types python/devices_and_streams + python/export python/ops python/random python/transforms diff --git a/docs/src/python/export.rst b/docs/src/python/export.rst new file mode 100644 index 000000000..9a1599096 --- /dev/null +++ b/docs/src/python/export.rst @@ -0,0 +1,14 @@ +.. _export: + +Export Functions +================ + +.. currentmodule:: mlx.core + +.. autosummary:: + :toctree: _autosummary + + export_function + import_function + exporter + export_to_dot diff --git a/examples/export/CMakeLists.txt b/examples/export/CMakeLists.txt new file mode 100644 index 000000000..c30011406 --- /dev/null +++ b/examples/export/CMakeLists.txt @@ -0,0 +1,27 @@ +cmake_minimum_required(VERSION 3.27) + +project(import_mlx LANGUAGES CXX) + +# ----------------------------- Setup ----------------------------- +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +# ----------------------------- Dependencies ----------------------------- +find_package( + Python 3.9 + COMPONENTS Interpreter Development.Module + REQUIRED) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m pip show mlx + COMMAND grep location + COMMAND awk "{print $4 \"/mlx\"}" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE MLX_ROOT) +find_package(MLX CONFIG REQUIRED) + +add_executable(eval_mlp eval_mlp.cpp) +target_link_libraries(eval_mlp PRIVATE mlx) + +add_executable(train_mlp train_mlp.cpp) +target_link_libraries(train_mlp PRIVATE mlx) diff --git a/examples/export/README.md b/examples/export/README.md new file mode 100644 index 000000000..4d9f77cdc --- /dev/null +++ b/examples/export/README.md @@ -0,0 +1,49 @@ +## Setup + +Install MLX: + +```bash +pip install mlx>=0.22 +``` + +Build the C++ examples: + +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build +``` + +## Run + +### Eval MLP + +Run the Python script to export the eval function: + +```bash +python eval_mlp.py +``` + +Then run the C++ program to import and run the function: + +``` +./build/eval_mlp +``` + +The Python and C++ programs should output the same result. + +### Train MLP + +Run the Python script to export the model initialization and training +functions: + +```bash +python train_mlp.py +``` + +Then run the C++ program to import and run the functions: + +``` +./build/train_mlp +``` + +The Python and C++ programs should output the same results. diff --git a/examples/export/eval_mlp.cpp b/examples/export/eval_mlp.cpp new file mode 100644 index 000000000..2facae43a --- /dev/null +++ b/examples/export/eval_mlp.cpp @@ -0,0 +1,25 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +using namespace mlx::core; + +int main() { + int batch_size = 8; + int input_dim = 32; + + // Make the input + random::seed(42); + auto example_x = random::uniform({batch_size, input_dim}); + + // Import the function + auto forward = import_function("eval_mlp.mlxfn"); + + // Call the imported function + auto out = forward({example_x})[0]; + + std::cout << out << std::endl; + + return 0; +} diff --git a/examples/export/eval_mlp.py b/examples/export/eval_mlp.py new file mode 100644 index 000000000..133ab6917 --- /dev/null +++ b/examples/export/eval_mlp.py @@ -0,0 +1,52 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import mlx.nn as nn +import mlx.utils + + +class MLP(nn.Module): + """A simple MLP.""" + + def __init__( + self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int + ): + super().__init__() + layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] + self.layers = [ + nn.Linear(idim, odim) + for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + def __call__(self, x): + for l in self.layers[:-1]: + x = nn.relu(l(x)) + return self.layers[-1](x) + + +if __name__ == "__main__": + + batch_size = 8 + input_dim = 32 + output_dim = 10 + + # Load the model + mx.random.seed(0) # Seed for params + model = MLP(num_layers=5, input_dim=input_dim, hidden_dim=64, output_dim=output_dim) + mx.eval(model) + + # Note, the model parameters are saved in the export function + def forward(x): + return model(x) + + mx.random.seed(42) # Seed for input + example_x = mx.random.uniform(shape=(batch_size, input_dim)) + + mx.export_function("eval_mlp.mlxfn", forward, example_x) + + # Import in Python + imported_forward = mx.import_function("eval_mlp.mlxfn") + expected = forward(example_x) + (out,) = imported_forward(example_x) + assert mx.allclose(expected, out) + print(out) diff --git a/examples/export/train_mlp.cpp b/examples/export/train_mlp.cpp new file mode 100644 index 000000000..c3d516e9e --- /dev/null +++ b/examples/export/train_mlp.cpp @@ -0,0 +1,35 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +using namespace mlx::core; + +int main() { + int batch_size = 8; + int input_dim = 32; + int output_dim = 10; + + auto state = import_function("init_mlp.mlxfn")({}); + + // Make the input + random::seed(42); + auto example_X = random::normal({batch_size, input_dim}); + auto example_y = random::randint(0, output_dim, {batch_size}); + + // Import the function + auto step = import_function("train_mlp.mlxfn"); + + // Call the imported function + for (int it = 0; it < 100; ++it) { + state.insert(state.end(), {example_X, example_y}); + state = step(state); + eval(state); + auto loss = state.back(); + state.pop_back(); + if (it % 10 == 0) { + std::cout << "Loss " << loss.item() << std::endl; + } + } + return 0; +} diff --git a/examples/export/train_mlp.py b/examples/export/train_mlp.py new file mode 100644 index 000000000..0e19650a1 --- /dev/null +++ b/examples/export/train_mlp.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import mlx.utils + + +class MLP(nn.Module): + """A simple MLP.""" + + def __init__( + self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int + ): + super().__init__() + layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim] + self.layers = [ + nn.Linear(idim, odim) + for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + def __call__(self, x): + for l in self.layers[:-1]: + x = nn.relu(l(x)) + return self.layers[-1](x) + + +if __name__ == "__main__": + + batch_size = 8 + input_dim = 32 + output_dim = 10 + + def init(): + # Seed for the parameter initialization + mx.random.seed(0) + model = MLP( + num_layers=3, input_dim=input_dim, hidden_dim=64, output_dim=output_dim + ) + optimizer = optim.SGD(learning_rate=1e-1) + optimizer.init(model.parameters()) + state = [model.parameters(), optimizer.state] + tree_structure, state = zip(*mlx.utils.tree_flatten(state)) + return model, optimizer, tree_structure, state + + # Export the model parameter initialization + model, optimizer, tree_structure, state = init() + mx.eval(state) + mx.export_function("init_mlp.mlxfn", lambda: init()[-1]) + + def loss_fn(params, X, y): + model.update(params) + return nn.losses.cross_entropy(model(X), y, reduction="mean") + + def step(*inputs): + *state, X, y = inputs + params, opt_state = mlx.utils.tree_unflatten(list(zip(tree_structure, state))) + optimizer.state = opt_state + loss, grads = mx.value_and_grad(loss_fn)(params, X, y) + params = optimizer.apply_gradients(grads, params) + _, state = zip(*mlx.utils.tree_flatten([params, optimizer.state])) + return *state, loss + + # Make some random data + mx.random.seed(42) + example_X = mx.random.normal(shape=(batch_size, input_dim)) + example_y = mx.random.randint(low=0, high=output_dim, shape=(batch_size,)) + mx.export_function("train_mlp.mlxfn", step, *state, example_X, example_y) + + # Export one step of SGD + imported_step = mx.import_function("train_mlp.mlxfn") + + for it in range(100): + *state, loss = imported_step(*state, example_X, example_y) + if it % 10 == 0: + print(f"Loss {loss.item():.6}") diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 8ec82d177..c7ef4670f 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 35b460f2d..924dc80d6 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -156,9 +156,6 @@ CompileMode& compile_mode() { return compile_mode_; } -using ParentsMap = - std::unordered_map>>; - // Helper like below but only merges the two provided arrays. If the src has // siblings then these won't be merged to the dst. void merge_one(array& dst, array& src, ParentsMap& parents_map) { @@ -732,10 +729,15 @@ std::vector compile_replace( trace_to_real.insert({trace_inputs[i].id(), inputs[i]}); } + auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); }; + for (auto& a : tape) { - // Arrays in the tape without primitives are constants - // and can be used directly - if (!a.has_primitive()) { + // Arrays in the tape without primitives are either: + // - inputs, which are already in the map + // - constants, which can be used directly + // - a load primitive which has no inputs and will become a constant + // after the first eval + if (!a.has_primitive() || is_load(a.primitive())) { trace_to_real.insert({a.id(), a}); } else { // Find real inputs diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 913079bfb..f76cc463f 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -2,7 +2,7 @@ #pragma once -#include "mlx/device.h" +#include "mlx/array.h" namespace mlx::core::detail { @@ -22,4 +22,35 @@ void compile_erase(std::uintptr_t fun_id); void compile_clear_cache(); bool compile_available_for_device(const Device& device); + +std::pair, std::vector> compile_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs); + +using ParentsMap = + std::unordered_map>>; + +// Traverses the graph to build a tape and a map of array ids to their parents +std::pair, ParentsMap> compile_dfs( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& original_inputs); + +// Simplify the tape. Note, this function modifies in-place both the tape and +// the parents map to remove orphaned arrays +void compile_simplify( + std::vector& tape, + ParentsMap& parents_map, + const std::vector& outputs, + int passes); + +std::vector compile_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs, + bool shapeless); + +void compile_validate_shapeless(const std::vector& tape); + } // namespace mlx::core::detail diff --git a/mlx/export.cpp b/mlx/export.cpp new file mode 100644 index 000000000..377c6c601 --- /dev/null +++ b/mlx/export.cpp @@ -0,0 +1,867 @@ +// Copyright © 2024 Apple Inc. +#include "mlx/export.h" +#include "mlx/compile_impl.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +#define STRINGIFY(x) #x +#define TOSTRING(x) STRINGIFY(x) + +// clang-format off +#define SERIALIZE_PRIMITIVE(primitive, keys...) \ + { \ + #primitive, { \ + [](Writer& os, const Primitive& p) { \ + serialize_primitive(os, p); \ + }, \ + [](Reader& is, Stream s) { \ + return deserialize_primitive(is, s); \ + }, \ + {keys} \ + } \ + } +// clang-format on + +bool is_big_endian() { + int num = 1; + return *reinterpret_cast(&num) != 1; +} + +namespace mlx::core { + +using namespace mlx::core::fast; + +using Reader = io::ParallelFileReader; +using Writer = io::FileWriter; + +struct PrimitiveSerializer { + using Serializer = std::function; + using Deserializer = + std::function(Reader&, Stream s)>; + PrimitiveSerializer( + Serializer serialize, + Deserializer deserialize, + std::vector keys = {}) + : serialize(std::move(serialize)), + deserialize(std::move(deserialize)), + keys(std::move(keys)) {}; + Serializer serialize; + Deserializer deserialize; + std::vector keys; +}; + +template +constexpr bool is_iterable = false; + +template +constexpr bool is_iterable< + T, + std::void_t< + decltype(std::declval().begin()), + decltype(std::declval().end())>> = true; + +template