diff --git a/docs/src/index.rst b/docs/src/index.rst
index 99cb7a8af..075861e88 100644
--- a/docs/src/index.rst
+++ b/docs/src/index.rst
@@ -45,6 +45,7 @@ are the CPU and GPU.
    usage/numpy
    usage/distributed
    usage/using_streams
+   usage/export
 
 .. toctree::
    :caption: Examples
diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst
index d0c13a9a8..7fe0ffd4f 100644
--- a/docs/src/usage/compile.rst
+++ b/docs/src/usage/compile.rst
@@ -422,6 +422,10 @@ the most opportunity to optimize the computation graph:
   # be faster even though the inner functions are compiled
   fun = mx.compile(outer)
 
+
+
+.. _shapeless_compile:
+
 Shapeless Compilation
 ---------------------
 
diff --git a/docs/src/usage/export.rst b/docs/src/usage/export.rst
new file mode 100644
index 000000000..812073609
--- /dev/null
+++ b/docs/src/usage/export.rst
@@ -0,0 +1,288 @@
+.. _export_usage:
+
+Exporting Functions
+===================
+
+.. currentmodule:: mlx.core
+
+MLX has an API to export and import functions to and from a file. This lets you
+run computations written in one MLX front-end (e.g. Python) in another MLX
+front-end (e.g. C++).
+
+This guide walks through the basics of the MLX export API with some examples.
+To see the full list of functions check out the :ref:`API documentation
+`.
+
+Basics of Exporting
+-------------------
+
+Let's start with a simple example:
+
+.. code-block:: python
+
+  def fun(x, y):
+      return x + y
+
+  x = mx.array(1.0)
+  y = mx.array(1.0)
+  mx.export_function("add.mlxfn", fun, x, y)
+
+To export a function, provide sample input arrays that the function
+can be called with. The data doesn't matter, but the shapes and types of the
+arrays do. In the above example we exported ``fun`` with two ``float32``
+scalar arrays. We can then import the function and run it:
+
+.. code-block:: python
+
+  add_fun = mx.import_function("add.mlxfn")
+
+  out, = add_fun(mx.array(1.0), mx.array(2.0))
+  # Prints: array(3, dtype=float32)
+  print(out)
+
+  out, = add_fun(mx.array(1.0), mx.array(3.0))
+  # Prints: array(4, dtype=float32)
+  print(out)
+
+  # Raises an exception
+  add_fun(mx.array(1), mx.array(3.0))
+
+  # Raises an exception
+  add_fun(mx.array([1.0, 2.0]), mx.array(3.0))
+
+Notice the third and fourth calls to ``add_fun`` raise exceptions because the
+shapes and types of the inputs differ from the shapes and types of the
+example inputs we exported the function with.
+
+Also notice that even though the original ``fun`` returns a single output
+array, the imported function always returns a tuple of one or more arrays.
+
+The inputs to :func:`export_function` and to an imported function can be
+specified as variable positional arguments or as a tuple of arrays:
+
+.. code-block:: python
+
+  def fun(x, y):
+      return x + y
+
+  x = mx.array(1.0)
+  y = mx.array(1.0)
+
+  # Both arguments to fun are positional
+  mx.export_function("add.mlxfn", fun, x, y)
+
+  # Same as above
+  mx.export_function("add.mlxfn", fun, (x, y))
+
+  imported_fun = mx.import_function("add.mlxfn")
+
+  # Ok
+  out, = imported_fun(x, y)
+
+  # Also ok
+  out, = imported_fun((x, y))
+
+You can pass example inputs to functions as positional or keyword arguments. If
+you use keyword arguments to export the function, then you have to use the same
+keyword arguments when calling the imported function.
+
+.. code-block:: python
+
+  def fun(x, y):
+      return x + y
+
+  # One argument to fun is positional, the other is a kwarg
+  mx.export_function("add.mlxfn", fun, x, y=y)
+
+  imported_fun = mx.import_function("add.mlxfn")
+
+  # Ok
+  out, = imported_fun(x, y=y)
+
+  # Also ok
+  out, = imported_fun((x,), {"y": y})
+
+  # Raises since the keyword argument is missing
+  out, = imported_fun(x, y)
+
+  # Raises since the keyword argument has the wrong key
+  out, = imported_fun(x, z=y)
+
+
+Exporting Modules
+-----------------
+
+An :obj:`mlx.nn.Module` can be exported with or without the parameters included
+in the exported function. Here's an example:
+
+.. code-block:: python
+
+  model = nn.Linear(4, 4)
+  mx.eval(model.parameters())
+
+  def call(x):
+      return model(x)
+
+  mx.export_function("model.mlxfn", call, mx.zeros(4))
+
+In the above example, the :obj:`mlx.nn.Linear` module is exported. Its
+parameters are also saved to the ``model.mlxfn`` file.
+
+.. note::
+
+   For enclosed arrays inside an exported function, be extra careful to ensure
+   they are evaluated. The computation graph that gets exported will include
+   the computation that produces enclosed inputs.
+
+   If the above example were missing ``mx.eval(model.parameters())``, the
+   exported function would include the random initialization of the
+   :obj:`mlx.nn.Module` parameters.
+
+If you only want to export the ``Module.__call__`` function without the
+parameters, pass them as inputs to the ``call`` wrapper:
+
+.. code-block:: python
+
+  model = nn.Linear(4, 4)
+  mx.eval(model.parameters())
+
+  def call(x, **params):
+      # Set the model's parameters to the input parameters
+      model.update(tree_unflatten(list(params.items())))
+      return model(x)
+
+  params = dict(tree_flatten(model.parameters()))
+  mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
+
+
+Shapeless Exports
+-----------------
+
+Just like :func:`compile`, functions can also be exported for dynamically shaped
+inputs. Pass ``shapeless=True`` to :func:`export_function` or :func:`exporter`
+to export a function which can be used for inputs with variable shapes:
+
+.. code-block:: python
+
+  mx.export_function("fun.mlxfn", mx.abs, mx.array(0.0), shapeless=True)
+  imported_abs = mx.import_function("fun.mlxfn")
+
+  # Ok
+  out, = imported_abs(mx.array(-1.0))
+
+  # Also ok
+  out, = imported_abs(mx.array([-1.0, -2.0]))
+
+With ``shapeless=False`` (which is the default), the second call to
+``imported_abs`` would raise an exception with a shape mismatch.
+
+Shapeless exporting works the same as shapeless compilation and should be
+used carefully. See the :ref:`documentation on shapeless compilation
+` for more information.
+
+Exporting Multiple Traces
+-------------------------
+
+In some cases, functions build different computation graphs for different
+input arguments. A simple way to manage this is to export to a new file with
+each set of inputs. This is a fine option in many cases. But it can be
+suboptimal if the exported functions have a large amount of duplicate constant
+data (for example the parameters of a :obj:`mlx.nn.Module`).
+
+The export API in MLX lets you export multiple traces of the same function to
+a single file by creating an exporting context manager with :func:`exporter`:
+
+.. code-block:: python
+
+  def fun(x, y=None):
+      constant = mx.array(3.0)
+      if y is not None:
+          x += y
+      return x + constant
+
+  with mx.exporter("fun.mlxfn", fun) as exporter:
+      exporter(mx.array(1.0))
+      exporter(mx.array(1.0), y=mx.array(0.0))
+
+  imported_function = mx.import_function("fun.mlxfn")
+
+  # Call the function with y=None
+  out, = imported_function(mx.array(1.0))
+  print(out)
+
+  # Call the function with y specified
+  out, = imported_function(mx.array(1.0), y=mx.array(1.0))
+  print(out)
+
+In the above example the function's constant data (i.e. ``constant``) is only
+saved once.
+
+Transformations with Imported Functions
+---------------------------------------
+
+Function transformations like :func:`grad`, :func:`vmap`, and :func:`compile` work
+on imported functions just like regular Python functions:
+
+.. code-block:: python
+
+  def fun(x):
+      return mx.sin(x)
+
+  x = mx.array(0.0)
+  mx.export_function("sine.mlxfn", fun, x)
+
+  imported_fun = mx.import_function("sine.mlxfn")
+
+  # Take the derivative of the imported function
+  dfdx = mx.grad(lambda x: imported_fun(x)[0])
+  # Prints: array(1, dtype=float32)
+  print(dfdx(x))
+
+  # Compile the imported function
+  compiled_fun = mx.compile(imported_fun)
+  # Prints: array(0, dtype=float32)
+  print(compiled_fun(x)[0])
+
+
+Importing Functions in C++
+--------------------------
+
+Importing and running functions in C++ is basically the same as importing and
+running them in Python. First, follow the :ref:`instructions ` to
+set up a simple C++ project that uses MLX as a library.
+
+Next, export a simple function from Python:
+
+.. code-block:: python
+
+  def fun(x, y):
+      return mx.exp(x + y)
+
+  x = mx.array(1.0)
+  y = mx.array(1.0)
+  mx.export_function("fun.mlxfn", fun, x, y)
+
+
+Import and run the function in C++ with only a few lines of code:
+
+.. code-block:: c++
+
+  auto fun = mx::import_function("fun.mlxfn");
+
+  auto inputs = {mx::array(1.0), mx::array(1.0)};
+  auto outputs = fun(inputs);
+
+  // Prints: array(7.38906, dtype=float32)
+  std::cout << outputs[0] << std::endl;
+
+Imported functions can be transformed in C++ just like in Python. Use
+``std::vector`` for positional arguments and ``std::map`` for keyword
+arguments when calling imported functions in C++.
+
+More Examples
+-------------
+
+Here are a few more complete examples exporting more complex functions from
+Python and importing and running them in C++:
+
+* `Inference and training a multi-layer perceptron `_
diff --git a/mlx/export.cpp b/mlx/export.cpp
index 743d20409..ae826e61c 100644
--- a/mlx/export.cpp
+++ b/mlx/export.cpp
@@ -729,8 +729,8 @@ std::vector ImportedFunction::operator()(
     msg << "[import_function::call] No imported function found which matches "
         << "the given positional and keyword arguments. Possible functions include:\n";
     ftable->print_functions(msg);
-    msg << "\nReceived function with " << args.size()
-        << " positional inputs and " << kwargs.size() << " keyword inputs:\n";
+    msg << "\nCalled with " << args.size() << " positional inputs and "
+        << kwargs.size() << " keyword inputs:\n";
     for (int i = 0; i < args.size(); ++i) {
       auto& in = args[i];
       msg << " " << i + 1 << ": " << in.shape() << " " << in.dtype() << "\n";