export with callback

This commit is contained in:
Awni Hannun
2025-09-22 12:57:18 -07:00
parent a95d4a74d9
commit 9fcfcf04c6
3 changed files with 89 additions and 54 deletions

View File

@@ -14,6 +14,7 @@
#primitive, { \ #primitive, { \
serialize_primitive<primitive>, \ serialize_primitive<primitive>, \
deserialize_primitive<primitive>, \ deserialize_primitive<primitive>, \
primitive_state<primitive>, \
{__VA_ARGS__} \ {__VA_ARGS__} \
} \ } \
} }
@@ -35,15 +36,20 @@ struct PrimitiveSerializer {
using Serializer = std::function<void(Writer&, const Primitive&)>; using Serializer = std::function<void(Writer&, const Primitive&)>;
using Deserializer = using Deserializer =
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>; std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
PrimitiveSerializer( PrimitiveSerializer(
Serializer serialize, Serializer serialize,
Deserializer deserialize, Deserializer deserialize,
StateExtractor extract_state,
std::vector<std::string> keys = {}) std::vector<std::string> keys = {})
: serialize(std::move(serialize)), : serialize(std::move(serialize)),
deserialize(std::move(deserialize)), deserialize(std::move(deserialize)),
extract_state(std::move(extract_state)),
keys(std::move(keys)) {}; keys(std::move(keys)) {};
Serializer serialize; Serializer serialize;
Deserializer deserialize; Deserializer deserialize;
StateExtractor extract_state;
std::vector<std::string> keys; std::vector<std::string> keys;
}; };
@@ -199,25 +205,26 @@ void serialize_primitive(Writer& os, const Primitive& p) {
} }
} }
// Possible types for a Primitive's state
using StateT = std::variant<int, float, std::vector<int>>;
template <typename T> template <typename T>
void extract_state(T v, std::vector<StateT>& state) { void extract_state(const T state, std::vector<StateT>& unpacked_state) {
if constexpr (std::is_arithmetic_v<T>) { if constexpr (std::is_arithmetic_v<T>) {
state.push_back(v); unpacked_state.push_back(state);
} else if constexpr (std::is_enum_v<T>) { } else if constexpr (std::is_enum_v<T>) {
state.push_back(static_cast<int>(v)); unpacked_state.push_back(static_cast<int>(state));
} else if constexpr (is_iterable<T>) {
unpacked_state.push_back(state);
} else if constexpr (is_pair<T> || is_tuple<T>) {
std::apply(
[&unpacked_state](auto&... x) {
(..., extract_state(x, unpacked_state));
},
state);
} }
// } else if constexpr (is_iterable<T>) {
// state.push_back(v);
// } else if constexpr (is_pair<T> || is_tuple<T>) {
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
// }
} }
// std::vector<StateT> extract_state(const Primitive& p) {
template <typename T> template <typename T>
std::vector<StateT> extract_state(const Primitive& p) { std::vector<StateT> primitive_state(const Primitive& p) {
std::vector<StateT> state; std::vector<StateT> state;
if constexpr (has_state<T>) { if constexpr (has_state<T>) {
extract_state(static_cast<const T&>(p).state(), state); extract_state(static_cast<const T&>(p).state(), state);
@@ -410,6 +417,23 @@ struct PrimitiveFactory {
"[import_function] Unable to deserialize primitive " + name); "[import_function] Unable to deserialize primitive " + name);
} }
}; };
std::pair<std::string, std::vector<StateT>> extract_state(
const std::shared_ptr<Primitive>& p) {
std::string name = p->name();
name = name.substr(0, name.find(' '));
if (auto it = name_remap.find(name); it != name_remap.end()) {
name = it->second;
}
if (auto it = factory.find(name); it != factory.end()) {
auto state = it->second.extract_state(*p);
return {name, state};
} else {
throw std::invalid_argument(
"[export_function] Unable to get state for primitive " + name);
}
};
}; };
void write_header(Writer& os, int count, bool shapeless) { void write_header(Writer& os, int count, bool shapeless) {
@@ -609,26 +633,30 @@ void FunctionExporter::export_with_callback(
for (auto& in : inputs) { for (auto& in : inputs) {
input_set.insert(in.id()); input_set.insert(in.id());
} }
std::vector<std::pair<std::string, array>> constants; std::vector<std::pair<std::string, array>> new_constants;
for (auto& arr : tape) { for (auto& arr : tape) {
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) { if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
continue; continue;
} }
constants.emplace_back(namer.get_name(arr), arr); if (constants.insert(arr.id()).second) {
new_constants.emplace_back(namer.get_name(arr), arr);
} }
callback({{"constants", constants}});
} }
callback({{"constants", new_constants}});
}
auto factory = PrimitiveFactory();
// Callback for each primitive in the tape
for (auto& arr : tape) { for (auto& arr : tape) {
if (!arr.has_primitive()) { if (!arr.has_primitive()) {
continue; continue;
} }
callback({ auto [name, state] = factory.extract_state(arr.primitive_ptr());
{"inputs", to_vector_data(arr.inputs())}, callback(
{{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())}, {"outputs", to_vector_data(arr.outputs())},
{"name", arr.primitive().name()} {"primitive", name},
////// {"state": []}); {"state", state}});
});
} }
} }
@@ -676,7 +704,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3); detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
// Update the table entry // Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys); fentry.kwarg_keys = kwarg_keys;
fentry.inputs = trace_inputs; fentry.inputs = trace_inputs;
count++; count++;

View File

@@ -4,17 +4,34 @@
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
#include <variant>
#include "mlx/array.h" #include "mlx/array.h"
namespace mlx::core { namespace mlx::core {
using Args = std::vector<array>; using Args = std::vector<array>;
using Kwargs = std::unordered_map<std::string, array>; using Kwargs = std::unordered_map<std::string, array>;
// Possible types for a Primitive's state
using StateT = std::variant<
bool,
int,
size_t,
float,
double,
Dtype,
Shape,
Strides,
std::vector<int>,
std::vector<size_t>,
std::string>;
using ExportCallbackInput = std::unordered_map< using ExportCallbackInput = std::unordered_map<
std::string, std::string,
std::variant< std::variant<
std::vector<std::tuple<std::string, Shape, Dtype>>, std::vector<std::tuple<std::string, Shape, Dtype>>,
std::vector<std::pair<std::string, array>>, std::vector<std::pair<std::string, array>>,
std::vector<StateT>,
std::string>>; std::string>>;
using ExportCallback = std::function<void(const ExportCallbackInput&)>; using ExportCallback = std::function<void(const ExportCallbackInput&)>;

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h> #include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h> #include <nanobind/stl/unordered_map.h>
@@ -133,50 +134,38 @@ auto wrap_export_function(nb::callable fun) {
void init_export(nb::module_& m) { void init_export(nb::module_& m) {
m.def( m.def(
"export_function", "export_function",
[](const nb::callable& callback, [](nb::object& file_or_callback,
const nb::callable& fun, const nb::callable& fun,
const nb::args& args, const nb::args& args,
bool shapeless, bool shapeless,
const nb::kwargs& kwargs) { const nb::kwargs& kwargs) {
auto [args_, kwargs_] = auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]"); validate_and_extract_inputs(args, kwargs, "[export_function]");
if (nb::isinstance<nb::str>(file_or_callback)) {
mx::export_function(
nb::cast<std::string>(file_or_callback),
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
} else {
auto callback = nb::cast<nb::callable>(file_or_callback);
auto wrapped_callback = auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) { [callback](const mx::ExportCallbackInput& input) {
return callback(input); return callback(input);
}; };
mx::export_function( mx::export_function(
wrapped_callback, callback, wrap_export_function(fun), args_, kwargs_, shapeless);
wrap_export_function(fun), }
args_,
kwargs_,
shapeless);
}, },
"callback"_a, nb::arg(),
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a);
m.def(
"export_function",
[](const std::string& file,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
mx::export_function(
file, wrap_export_function(fun), args_, kwargs_, shapeless);
},
"file"_a,
"fun"_a, "fun"_a,
"args"_a, "args"_a,
nb::kw_only(), nb::kw_only(),
"shapeless"_a = false, "shapeless"_a = false,
"kwargs"_a, "kwargs"_a,
R"pbdoc( R"pbdoc(
Export a function to a file. Export an MLX function.
Example input arrays must be provided to export a function. The example Example input arrays must be provided to export a function. The example
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
@@ -189,7 +178,8 @@ void init_export(nb::module_& m) {
versions of MLX may not be compatible with future versions. versions of MLX may not be compatible with future versions.
Args: Args:
file (str): File path to export the function to. file (str or Callable): Either a file path to export the function
to or a callback.
fun (Callable): A function which takes as input zero or more fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`. :class:`array` and returns one or more :class:`array`.
*args (array): Example array inputs to the function. *args (array): Example array inputs to the function.