Export with callback (#2612)

* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
This commit is contained in:
Awni Hannun
2025-10-08 19:24:33 -07:00
committed by GitHub
parent 85a8824a8c
commit e89e8b4272
10 changed files with 370 additions and 33 deletions

View File

@@ -1,8 +1,11 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <fstream>
@@ -131,24 +134,38 @@ auto wrap_export_function(nb::callable fun) {
void init_export(nb::module_& m) {
m.def(
"export_function",
[](const std::string& file,
[](nb::object& file_or_callback,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
mx::export_function(
file, wrap_export_function(fun), args_, kwargs_, shapeless);
if (nb::isinstance<nb::str>(file_or_callback)) {
mx::export_function(
nb::cast<std::string>(file_or_callback),
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
} else {
auto callback = nb::cast<nb::callable>(file_or_callback);
auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) {
return callback(input);
};
mx::export_function(
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
}
},
"file"_a,
nb::arg(),
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a,
R"pbdoc(
Export a function to a file.
Export an MLX function.
Example input arrays must be provided to export a function. The example
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
@@ -161,7 +178,8 @@ void init_export(nb::module_& m) {
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
file (str or Callable): Either a file path to export the function
to or a callback.
fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`.
*args (array): Example array inputs to the function.