Export / import functions to / from a file (#1642)

* export and import functions

* refactor + works for few primitives

* nit

* allow primitives with state

* nit

* nit

* simplify serialize / deserialize

* fix for constants

* python bindings

* maybe fix serialize failure case

* add example

* more primitives, training kind of works

* same result for python and c++

* some fixes

* fix export

* template it up

* some simplificatoin

* rebase

* allow kwargs and multiple functions

* exporter

* more primitives for exporting

* deal with endianness

* handle invalid stream

* add docstring
This commit is contained in:
Awni Hannun
2024-12-24 11:19:13 -08:00
committed by GitHub
parent 935c8c4bb1
commit 4ba0c24a8f
35 changed files with 2239 additions and 90 deletions

View File

@@ -11,6 +11,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

View File

@@ -331,7 +331,7 @@ PyScalarT validate_shape(
t = pycomplex;
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(l.type()).c_str()
msg << "Invalid type " << nb::type_name(l.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}

277
python/src/export.cpp Normal file
View File

@@ -0,0 +1,277 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/map.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <fstream>
#include "mlx/array.h"
#include "mlx/export.h"
#include "mlx/graph_utils.h"
#include "python/src/trees.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
std::pair<std::vector<mx::array>, std::map<std::string, mx::array>>
validate_and_extract_inputs(
const nb::args& args,
const nb::kwargs& kwargs,
const std::string& prefix) {
auto maybe_throw = [&prefix](bool valid) {
if (!valid) {
throw std::invalid_argument(
prefix +
" Inputs can either be a variable "
"number of positional and keyword arrays or a single tuple "
"and/or dictionary of arrays.");
}
};
std::vector<mx::array> args_;
std::map<std::string, mx::array> kwargs_;
if (args.size() == 0) {
// No args so kwargs must be keyword arrays
maybe_throw(nb::try_cast(kwargs, kwargs_));
} else if (args.size() > 0 && nb::isinstance<mx::array>(args[0])) {
// Args are positional arrays and kwargs are keyword arrays
maybe_throw(nb::try_cast(args, args_));
maybe_throw(nb::try_cast(kwargs, kwargs_));
} else if (args.size() == 1) {
// - args[0] can be a tuple or list or arrays or a dict
// with string keys and array values
// - kwargs should be empty
maybe_throw(kwargs.size() == 0);
if (!nb::try_cast(args[0], args_)) {
maybe_throw(nb::try_cast(args[0], kwargs_));
}
} else if (args.size() == 2) {
// - args[0] can be a tuple or list of arrays
// - args[1] can be a dict of string keys with array values.
// - kwargs should be empty
maybe_throw(kwargs.size() == 0);
maybe_throw(nb::try_cast(args[0], args_));
maybe_throw(nb::try_cast(args[1], kwargs_));
} else {
maybe_throw(false);
}
return {args_, kwargs_};
}
auto wrap_export_function(const nb::callable& fun) {
return [fun](
const std::vector<mx::array>& args_,
const std::map<std::string, mx::array>& kwargs_) {
auto kwargs = nb::dict();
kwargs.update(nb::cast(kwargs_));
auto args = nb::tuple(nb::cast(args_));
auto outputs = fun(*args, **kwargs);
std::vector<mx::array> outputs_;
if (nb::isinstance<mx::array>(outputs)) {
outputs_.push_back(nb::cast<mx::array>(outputs));
} else if (!nb::try_cast(outputs, outputs_)) {
throw std::invalid_argument(
"[export_function] Outputs can be either a single array "
"a tuple or list of arrays.");
}
return outputs_;
};
}
void init_export(nb::module_& m) {
m.def(
"export_function",
[](const std::string& file,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
mx::export_function(
file, wrap_export_function(fun), args_, kwargs_, shapeless);
},
"file"_a,
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a,
R"pbdoc(
Export a function to a file.
Example input arrays must be provided to export a function. The example
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
and/or dictionary of string keys with array values.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`.
*args (array): Example array inputs to the function.
shapeless (bool, optional): Whether or not the function allows
inputs with variable shapes. Default: ``False``.
**kwargs (array): Additional example keyword array inputs to the
function.
Example:
.. code-block:: python
def fun(x, y):
return x + y
x = mx.array(1)
y = mx.array([1, 2, 3])
mx.export_function("fun.mlxfn", fun, x, y=y)
)pbdoc");
m.def(
"import_function",
[](const std::string& file) {
return nb::cpp_function(
[fn = mx::import_function(file)](
const nb::args& args, const nb::kwargs& kwargs) {
auto [args_, kwargs_] = validate_and_extract_inputs(
args, kwargs, "[import_function::call]");
return nb::tuple(nb::cast(fn(args_, kwargs_)));
});
},
"file"_a,
nb::sig("def import_function(file: str) -> Callable"),
R"pbdoc(
Import a function from a file.
The imported function can be called either with ``*args`` and
``**kwargs`` or with a tuple of arrays and/or dictionary of string
keys with array values. Imported functions always return a tuple of
arrays.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): The file path to import the function from.
Returns:
Callable: The imported function.
Example:
>>> fn = mx.import_function("function.mlxfn")
>>> out = fn(a, b, x=x, y=y)[0]
>>>
>>> out = fn((a, b), {"x": x, "y": y}[0]
)pbdoc");
nb::class_<mx::FunctionExporter>(
m,
"FunctionExporter",
R"pbdoc(
A context managing class for exporting multiple traces of the same
function to a file.
Make an instance of this class by calling fun:`mx.exporter`.
)pbdoc")
.def("close", &mx::FunctionExporter::close)
.def(
"__enter__", [](mx::FunctionExporter& exporter) { return &exporter; })
.def(
"__exit__",
[](mx::FunctionExporter& exporter,
const std::optional<nb::object>&,
const std::optional<nb::object>&,
const std::optional<nb::object>&) { exporter.close(); },
"exc_type"_a = nb::none(),
"exc_value"_a = nb::none(),
"traceback"_a = nb::none())
.def(
"__call__",
[](mx::FunctionExporter& exporter,
const nb::args& args,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
exporter(args_, kwargs_);
});
m.def(
"exporter",
[](const std::string& file, const nb::callable& fun, bool shapeless) {
return mx::exporter(file, wrap_export_function(fun), shapeless);
},
"file"_a,
"fun"_a,
nb::kw_only(),
"shapeless"_a = false,
R"pbdoc(
Make a callable object to export multiple traces of a function to a file.
.. warning::
This is part of an experimental API which is likely to
change in future versions of MLX. Functions exported with older
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
shapeless (bool, optional): Whether or not the function allows
inputs with variable shapes. Default: ``False``.
Example:
.. code-block:: python
def fun(*args):
return sum(args)
with mx.exporter("fun.mlxfn", fun) as exporter:
exporter(mx.array(1))
exporter(mx.array(1), mx.array(2))
exporter(mx.array(1), mx.array(2), mx.array(3))
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
mx::export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
mx::export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"[export_to_dot] Accepts file-like objects or strings "
"to be used as filenames.");
}
},
"file"_a,
"args"_a,
R"pbdoc(
Export a graph to DOT format for visualization.
A variable number of output arrays can be provided for exporting
The graph exported will recursively include all enevaluated inputs of
the provided outputs.
Args:
file (str): The file path to export to.
*args (array): The output arrays.
Example:
>>> a = mx.array(1) + mx.array(2)
>>> mx.export_to_dot("graph.dot", a)
)pbdoc");
}

View File

@@ -19,6 +19,7 @@ void init_linalg(nb::module_&);
void init_constants(nb::module_&);
void init_fast(nb::module_&);
void init_distributed(nb::module_&);
void init_export(nb::module_&);
NB_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -39,6 +40,7 @@ NB_MODULE(core, m) {
init_constants(m);
init_fast(m);
init_distributed(m);
init_export(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}

View File

@@ -2898,7 +2898,7 @@ void init_ops(nb::module_& m) {
Generate multidimensional coordinate grids from 1-D coordinate arrays
Args:
arrays (array): Input arrays.
*arrays (array): Input arrays.
sparse (bool, optional): If ``True``, a sparse grid is returned in which each output
array has a single non-zero element. If ``False``, a dense grid is returned.
Defaults to ``False``.
@@ -3840,8 +3840,8 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): Path to file to which the arrays are saved.
args (arrays): Arrays to be saved.
kwargs (arrays): Arrays to be saved. Each array will be saved
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
)pbdoc");
m.def(

View File

@@ -8,14 +8,12 @@
#include <nanobind/stl/vector.h>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <sstream>
#include "mlx/array.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
@@ -945,34 +943,34 @@ void init_transforms(nb::module_& m) {
Note, all custom transformations are optional. Undefined transformations
fall back to the default behaviour.
Example usage:
Example:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.core as mx
@mx.custom_function
def f(x, y):
return mx.sin(x) * y
@mx.custom_function
def f(x, y):
return mx.sin(x) * y
@f.vjp
def f_vjp(primals, cotangent, output):
@f.vjp
def f_vjp(primals, cotangent, output):
x, y = primals
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
@f.jvp
def f_jvp(primals, tangents):
x, y = primals
return cotan * mx.cos(x) * y, cotan * mx.sin(x)
dx, dy = tangents
return dx * mx.cos(x) * y + dy * mx.sin(x)
@f.jvp
def f_jvp(primals, tangents):
x, y = primals
dx, dy = tangents
return dx * mx.cos(x) * y + dy * mx.sin(x)
@f.vmap
def f_vmap(inputs, axes):
x, y = inputs
ax, ay = axes
if ay != ax and ax is not None:
y = y.swapaxes(ay, ax)
return mx.sin(x) * y, (ax or ay)
@f.vmap
def f_vmap(inputs, axes):
x, y = inputs
ax, ay = axes
if ay != ax and ax is not None:
y = y.swapaxes(ay, ax)
return mx.sin(x) * y, (ax or ay)
)pbdoc")
.def(
nb::init<nb::callable>(),
@@ -1313,25 +1311,6 @@ void init_transforms(nb::module_& m) {
Returns:
Callable: The vectorized function.
)pbdoc");
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a,
"args"_a);
m.def(
"compile",
[](const nb::callable& fun,

View File

@@ -0,0 +1,244 @@
# Copyright © 2024 Apple Inc.
import os
import tempfile
import unittest
import mlx.core as mx
import mlx_tests
class TestExportImport(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_basic_export_import(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
# Function with no inputs
def fun():
return mx.zeros((3, 3))
mx.export_function(path, fun)
imported = mx.import_function(path)
expected = fun()
(out,) = imported()
self.assertTrue(mx.array_equal(out, expected))
# Simple function with inputs
def fun(x):
return mx.abs(mx.sin(x))
inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
expected = fun(inputs)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Inputs in a list or tuple
def fun(x):
x = mx.abs(mx.sin(x))
return x
mx.export_function(path, fun, [inputs])
imported = mx.import_function(path)
expected = fun(inputs)
(out,) = imported([inputs])
self.assertTrue(mx.allclose(out, expected))
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
mx.export_function(path, fun, (inputs,))
imported = mx.import_function(path)
(out,) = imported((inputs,))
self.assertTrue(mx.allclose(out, expected))
# Outputs in a list
def fun(x):
return [mx.abs(mx.sin(x))]
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Outputs in a tuple
def fun(x):
return (mx.abs(mx.sin(x)),)
mx.export_function(path, fun, inputs)
imported = mx.import_function(path)
(out,) = imported(inputs)
self.assertTrue(mx.allclose(out, expected))
# Check throws on invalid inputs / outputs
def fun(x):
return mx.abs(x)
with self.assertRaises(ValueError):
mx.export_function(path, fun, "hi")
with self.assertRaises(ValueError):
mx.export_function(path, fun, mx.array(1.0), "hi")
def fun(x):
return mx.abs(x[0][0])
with self.assertRaises(ValueError):
mx.export_function(path, fun, [[mx.array(1.0)]])
def fun():
return (mx.zeros((3, 3)), 1)
with self.assertRaises(ValueError):
mx.export_function(path, fun)
def fun():
return (mx.zeros((3, 3)), [mx.zeros((3, 3))])
with self.assertRaises(ValueError):
mx.export_function(path, fun)
def fun(x, y):
return x + y
mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
imported = mx.import_function(path)
with self.assertRaises(ValueError):
imported(mx.array(1.0), 1.0)
with self.assertRaises(ValueError):
imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))
with self.assertRaises(ValueError):
imported(mx.array(1.0), [mx.array(1.0)])
def test_export_random_sample(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
mx.random.seed(5)
def fun():
return mx.random.uniform(shape=(3,))
mx.export_function(path, fun)
imported = mx.import_function(path)
(out,) = imported()
mx.random.seed(5)
expected = fun()
self.assertTrue(mx.array_equal(out, expected))
def test_export_with_kwargs(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(x, z=None):
out = x
if z is not None:
out += z
return out
x = mx.array([1, 2, 3])
y = mx.array([1, 1, 0])
z = mx.array([2, 2, 2])
mx.export_function(path, fun, (x,), {"z": z})
imported_fun = mx.import_function(path)
with self.assertRaises(ValueError):
imported_fun(x, z)
with self.assertRaises(ValueError):
imported_fun(x, y=z)
with self.assertRaises(ValueError):
imported_fun((x,), {"y": z})
out = imported_fun(x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun((x,), {"z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
mx.export_function(path, fun, x, z=z)
imported_fun = mx.import_function(path)
out = imported_fun(x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun((x,), {"z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
# Only specify kwargs
mx.export_function(path, fun, x=x, z=z)
imported_fun = mx.import_function(path)
with self.assertRaises(ValueError):
out = imported_fun(x, z=z)[0]
out = imported_fun(x=x, z=z)[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
out = imported_fun({"x": x, "z": z})[0]
self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))
def test_export_variable_inputs(self):
path = os.path.join(self.test_dir, "fn.mlxfn")
def fun(x, y, z=None):
out = x + y
if z is not None:
out += z
return out
with mx.exporter(path, fun) as exporter:
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))
with self.assertRaises(RuntimeError):
exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
imported_fun = mx.import_function(path)
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))
out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))[0]
self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))
with self.assertRaises(ValueError):
imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))
# A function with a large constant
constant = mx.zeros((16, 2048))
mx.eval(constant)
def fun(*args):
return constant + sum(args)
with mx.exporter(path, fun) as exporter:
for i in range(5):
exporter(*[mx.array(1)] * i)
# Check the exported file size < constant size + small amount
constants_size = constant.nbytes + 8192
self.assertTrue(os.path.getsize(path) < constants_size)
if __name__ == "__main__":
unittest.main()

View File

@@ -28,15 +28,14 @@ class TestLoad(mlx_tests.MLXTestCase):
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_and_load(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
@@ -64,9 +63,6 @@ class TestLoad(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_safetensors(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
test_file = os.path.join(self.test_dir, "test.safetensors")
with self.assertRaises(Exception):
mx.save_safetensors(test_file, {"a": mx.ones((4, 4))}, {"testing": 0})
@@ -330,9 +326,6 @@ class TestLoad(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
def test_non_contiguous(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
a = mx.broadcast_to(mx.array([1, 2]), [4, 2])
save_file = os.path.join(self.test_dir, "a.npy")