mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
export with callback
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <fstream>
|
||||
@@ -129,6 +131,32 @@ auto wrap_export_function(nb::callable fun) {
|
||||
}
|
||||
|
||||
void init_export(nb::module_& m) {
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const nb::callable& callback,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
auto wrapped_callback =
|
||||
[callback](const mx::ExportCallbackInput& input) {
|
||||
return callback(input);
|
||||
};
|
||||
mx::export_function(
|
||||
wrapped_callback,
|
||||
wrap_export_function(fun),
|
||||
args_,
|
||||
kwargs_,
|
||||
shapeless);
|
||||
},
|
||||
"callback"_a,
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a);
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const std::string& file,
|
||||
|
||||
Reference in New Issue
Block a user