export with callback

This commit is contained in:
Awni Hannun
2025-09-22 12:57:18 -07:00
parent a95d4a74d9
commit 9fcfcf04c6
3 changed files with 89 additions and 54 deletions

View File

@@ -14,6 +14,7 @@
#primitive, { \
serialize_primitive<primitive>, \
deserialize_primitive<primitive>, \
primitive_state<primitive>, \
{__VA_ARGS__} \
} \
}
@@ -35,15 +36,20 @@ struct PrimitiveSerializer {
using Serializer = std::function<void(Writer&, const Primitive&)>;
using Deserializer =
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
PrimitiveSerializer(
Serializer serialize,
Deserializer deserialize,
StateExtractor extract_state,
std::vector<std::string> keys = {})
: serialize(std::move(serialize)),
deserialize(std::move(deserialize)),
extract_state(std::move(extract_state)),
keys(std::move(keys)) {};
Serializer serialize;
Deserializer deserialize;
StateExtractor extract_state;
std::vector<std::string> keys;
};
@@ -199,25 +205,26 @@ void serialize_primitive(Writer& os, const Primitive& p) {
}
}
// Possible types for a Primitive's state
using StateT = std::variant<int, float, std::vector<int>>;
template <typename T>
void extract_state(T v, std::vector<StateT>& state) {
void extract_state(const T state, std::vector<StateT>& unpacked_state) {
if constexpr (std::is_arithmetic_v<T>) {
state.push_back(v);
unpacked_state.push_back(state);
} else if constexpr (std::is_enum_v<T>) {
state.push_back(static_cast<int>(v));
unpacked_state.push_back(static_cast<int>(state));
} else if constexpr (is_iterable<T>) {
unpacked_state.push_back(state);
} else if constexpr (is_pair<T> || is_tuple<T>) {
std::apply(
[&unpacked_state](auto&... x) {
(..., extract_state(x, unpacked_state));
},
state);
}
// } else if constexpr (is_iterable<T>) {
// state.push_back(v);
// } else if constexpr (is_pair<T> || is_tuple<T>) {
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
// }
}
// std::vector<StateT> extract_state(const Primitive& p) {
template <typename T>
std::vector<StateT> extract_state(const Primitive& p) {
std::vector<StateT> primitive_state(const Primitive& p) {
std::vector<StateT> state;
if constexpr (has_state<T>) {
extract_state(static_cast<const T&>(p).state(), state);
@@ -410,6 +417,23 @@ struct PrimitiveFactory {
"[import_function] Unable to deserialize primitive " + name);
}
};
std::pair<std::string, std::vector<StateT>> extract_state(
const std::shared_ptr<Primitive>& p) {
std::string name = p->name();
name = name.substr(0, name.find(' '));
if (auto it = name_remap.find(name); it != name_remap.end()) {
name = it->second;
}
if (auto it = factory.find(name); it != factory.end()) {
auto state = it->second.extract_state(*p);
return {name, state};
} else {
throw std::invalid_argument(
"[export_function] Unable to get state for primitive " + name);
}
};
};
void write_header(Writer& os, int count, bool shapeless) {
@@ -609,26 +633,30 @@ void FunctionExporter::export_with_callback(
for (auto& in : inputs) {
input_set.insert(in.id());
}
std::vector<std::pair<std::string, array>> constants;
std::vector<std::pair<std::string, array>> new_constants;
for (auto& arr : tape) {
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
continue;
}
constants.emplace_back(namer.get_name(arr), arr);
if (constants.insert(arr.id()).second) {
new_constants.emplace_back(namer.get_name(arr), arr);
}
}
callback({{"constants", constants}});
callback({{"constants", new_constants}});
}
auto factory = PrimitiveFactory();
// Callback for each primitive in the tape
for (auto& arr : tape) {
if (!arr.has_primitive()) {
continue;
}
callback({
{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())},
{"name", arr.primitive().name()}
////// {"state": []});
});
auto [name, state] = factory.extract_state(arr.primitive_ptr());
callback(
{{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())},
{"primitive", name},
{"state", state}});
}
}
@@ -676,7 +704,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
// Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys);
fentry.kwarg_keys = kwarg_keys;
fentry.inputs = trace_inputs;
count++;

View File

@@ -4,17 +4,34 @@
#include <set>
#include <unordered_map>
#include <variant>
#include "mlx/array.h"
namespace mlx::core {
using Args = std::vector<array>;
using Kwargs = std::unordered_map<std::string, array>;
// Possible types for a Primitive's state
using StateT = std::variant<
bool,
int,
size_t,
float,
double,
Dtype,
Shape,
Strides,
std::vector<int>,
std::vector<size_t>,
std::string>;
using ExportCallbackInput = std::unordered_map<
std::string,
std::variant<
std::vector<std::tuple<std::string, Shape, Dtype>>,
std::vector<std::pair<std::string, array>>,
std::vector<StateT>,
std::string>>;
using ExportCallback = std::function<void(const ExportCallbackInput&)>;

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h>
@@ -133,50 +134,38 @@ auto wrap_export_function(nb::callable fun) {
void init_export(nb::module_& m) {
m.def(
"export_function",
[](const nb::callable& callback,
[](nb::object& file_or_callback,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) {
return callback(input);
};
mx::export_function(
wrapped_callback,
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
if (nb::isinstance<nb::str>(file_or_callback)) {
mx::export_function(
nb::cast<std::string>(file_or_callback),
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
} else {
auto callback = nb::cast<nb::callable>(file_or_callback);
auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) {
return callback(input);
};
mx::export_function(
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
}
},
"callback"_a,
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a);
m.def(
"export_function",
[](const std::string& file,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
mx::export_function(
file, wrap_export_function(fun), args_, kwargs_, shapeless);
},
"file"_a,
nb::arg(),
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a,
R"pbdoc(
Export a function to a file.
Export an MLX function.
Example input arrays must be provided to export a function. The example
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
@@ -189,7 +178,8 @@ void init_export(nb::module_& m) {
versions of MLX may not be compatible with future versions.
Args:
file (str): File path to export the function to.
file (str or Callable): Either a file path to export the function
to or a callback.
fun (Callable): A function which takes as input zero or more
:class:`array` and returns one or more :class:`array`.
*args (array): Example array inputs to the function.