export with callback

This commit is contained in:
Awni Hannun
2025-08-04 11:48:36 -07:00
parent aa9d44b3d4
commit a95d4a74d9
5 changed files with 248 additions and 6 deletions

View File

@@ -2,7 +2,9 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <fstream>
@@ -129,6 +131,32 @@ auto wrap_export_function(nb::callable fun) {
}
void init_export(nb::module_& m) {
m.def(
"export_function",
[](const nb::callable& callback,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) {
return callback(input);
};
mx::export_function(
wrapped_callback,
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
},
"callback"_a,
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a);
m.def(
"export_function",
[](const std::string& file,