mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Export with callback (#2612)
* export with callback * export with callback * Add types, fix kwarg ordering bug + test * cleanup, test, fix * typos
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
@@ -131,24 +134,38 @@ auto wrap_export_function(nb::callable fun) {
|
||||
void init_export(nb::module_& m) {
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const std::string& file,
|
||||
[](nb::object& file_or_callback,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
mx::export_function(
|
||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
if (nb::isinstance<nb::str>(file_or_callback)) {
|
||||
mx::export_function(
|
||||
nb::cast<std::string>(file_or_callback),
|
||||
wrap_export_function(fun),
|
||||
args_,
|
||||
kwargs_,
|
||||
shapeless);
|
||||
} else {
|
||||
auto callback = nb::cast<nb::callable>(file_or_callback);
|
||||
auto wrapped_callback =
|
||||
[callback](const mx::ExportCallbackInput& input) {
|
||||
return callback(input);
|
||||
};
|
||||
mx::export_function(
|
||||
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
}
|
||||
},
|
||||
"file"_a,
|
||||
nb::arg(),
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a,
|
||||
R"pbdoc(
|
||||
Export a function to a file.
|
||||
Export an MLX function.
|
||||
|
||||
Example input arrays must be provided to export a function. The example
|
||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||
@@ -161,7 +178,8 @@ void init_export(nb::module_& m) {
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
file (str or Callable): Either a file path to export the function
|
||||
to or a callback.
|
||||
fun (Callable): A function which takes as input zero or more
|
||||
:class:`array` and returns one or more :class:`array`.
|
||||
*args (array): Example array inputs to the function.
|
||||
|
@@ -319,7 +319,7 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
# Check the state is unchanged
|
||||
self.assertEqual(state["y"], 2)
|
||||
|
||||
# Check the udpated state is used
|
||||
# Check the updated state is used
|
||||
state["y"] = mx.array(3)
|
||||
out = test_state(mx.array(1))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
@@ -485,6 +485,52 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
||||
)
|
||||
|
||||
def test_export_kwarg_ordering(self):
|
||||
path = os.path.join(self.test_dir, "fun.mlxfn")
|
||||
|
||||
def fn(x, y):
|
||||
return x - y
|
||||
|
||||
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
|
||||
imported = mx.import_function(path)
|
||||
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
|
||||
self.assertEqual(out.item(), -1.0)
|
||||
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
||||
self.assertEqual(out.item(), 1.0)
|
||||
|
||||
def test_export_with_callback(self):
|
||||
|
||||
def fn(x, y):
|
||||
return mx.log(mx.abs(x - y))
|
||||
|
||||
n_in = None
|
||||
n_out = None
|
||||
n_const = None
|
||||
keywords = None
|
||||
primitives = []
|
||||
|
||||
def callback(args):
|
||||
nonlocal n_in, n_out, n_const, keywords, primitives
|
||||
t = args["type"]
|
||||
if t == "inputs":
|
||||
n_in = len(args["inputs"])
|
||||
elif args["type"] == "outputs":
|
||||
n_out = len(args["outputs"])
|
||||
elif args["type"] == "keyword_inputs":
|
||||
keywords = args["keywords"]
|
||||
elif t == "constants":
|
||||
n_const = len(args["constants"])
|
||||
elif t == "primitive":
|
||||
primitives.append(args["name"])
|
||||
|
||||
mx.export_function(callback, fn, mx.array(1.0), y=mx.array(1.0))
|
||||
self.assertEqual(n_in, 2)
|
||||
self.assertEqual(n_out, 1)
|
||||
self.assertEqual(n_const, 0)
|
||||
self.assertEqual(len(keywords), 1)
|
||||
self.assertEqual(keywords[0][0], "y")
|
||||
self.assertEqual(primitives, ["Subtract", "Abs", "Log"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user