mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
export with callback
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#primitive, { \
|
||||
serialize_primitive<primitive>, \
|
||||
deserialize_primitive<primitive>, \
|
||||
primitive_state<primitive>, \
|
||||
{__VA_ARGS__} \
|
||||
} \
|
||||
}
|
||||
@@ -35,15 +36,20 @@ struct PrimitiveSerializer {
|
||||
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
||||
using Deserializer =
|
||||
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
||||
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
|
||||
|
||||
PrimitiveSerializer(
|
||||
Serializer serialize,
|
||||
Deserializer deserialize,
|
||||
StateExtractor extract_state,
|
||||
std::vector<std::string> keys = {})
|
||||
: serialize(std::move(serialize)),
|
||||
deserialize(std::move(deserialize)),
|
||||
extract_state(std::move(extract_state)),
|
||||
keys(std::move(keys)) {};
|
||||
Serializer serialize;
|
||||
Deserializer deserialize;
|
||||
StateExtractor extract_state;
|
||||
std::vector<std::string> keys;
|
||||
};
|
||||
|
||||
@@ -199,25 +205,26 @@ void serialize_primitive(Writer& os, const Primitive& p) {
|
||||
}
|
||||
}
|
||||
|
||||
// Possible types for a Primitive's state
|
||||
using StateT = std::variant<int, float, std::vector<int>>;
|
||||
|
||||
template <typename T>
|
||||
void extract_state(T v, std::vector<StateT>& state) {
|
||||
void extract_state(const T state, std::vector<StateT>& unpacked_state) {
|
||||
if constexpr (std::is_arithmetic_v<T>) {
|
||||
state.push_back(v);
|
||||
unpacked_state.push_back(state);
|
||||
} else if constexpr (std::is_enum_v<T>) {
|
||||
state.push_back(static_cast<int>(v));
|
||||
unpacked_state.push_back(static_cast<int>(state));
|
||||
} else if constexpr (is_iterable<T>) {
|
||||
unpacked_state.push_back(state);
|
||||
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
std::apply(
|
||||
[&unpacked_state](auto&... x) {
|
||||
(..., extract_state(x, unpacked_state));
|
||||
},
|
||||
state);
|
||||
}
|
||||
// } else if constexpr (is_iterable<T>) {
|
||||
// state.push_back(v);
|
||||
// } else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
|
||||
// }
|
||||
}
|
||||
|
||||
// std::vector<StateT> extract_state(const Primitive& p) {
|
||||
template <typename T>
|
||||
std::vector<StateT> extract_state(const Primitive& p) {
|
||||
std::vector<StateT> primitive_state(const Primitive& p) {
|
||||
std::vector<StateT> state;
|
||||
if constexpr (has_state<T>) {
|
||||
extract_state(static_cast<const T&>(p).state(), state);
|
||||
@@ -410,6 +417,23 @@ struct PrimitiveFactory {
|
||||
"[import_function] Unable to deserialize primitive " + name);
|
||||
}
|
||||
};
|
||||
|
||||
std::pair<std::string, std::vector<StateT>> extract_state(
|
||||
const std::shared_ptr<Primitive>& p) {
|
||||
std::string name = p->name();
|
||||
name = name.substr(0, name.find(' '));
|
||||
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||
name = it->second;
|
||||
}
|
||||
|
||||
if (auto it = factory.find(name); it != factory.end()) {
|
||||
auto state = it->second.extract_state(*p);
|
||||
return {name, state};
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[export_function] Unable to get state for primitive " + name);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
void write_header(Writer& os, int count, bool shapeless) {
|
||||
@@ -609,26 +633,30 @@ void FunctionExporter::export_with_callback(
|
||||
for (auto& in : inputs) {
|
||||
input_set.insert(in.id());
|
||||
}
|
||||
std::vector<std::pair<std::string, array>> constants;
|
||||
std::vector<std::pair<std::string, array>> new_constants;
|
||||
for (auto& arr : tape) {
|
||||
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||
continue;
|
||||
}
|
||||
constants.emplace_back(namer.get_name(arr), arr);
|
||||
if (constants.insert(arr.id()).second) {
|
||||
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||
}
|
||||
}
|
||||
callback({{"constants", constants}});
|
||||
callback({{"constants", new_constants}});
|
||||
}
|
||||
auto factory = PrimitiveFactory();
|
||||
|
||||
// Callback for each primitive in the tape
|
||||
for (auto& arr : tape) {
|
||||
if (!arr.has_primitive()) {
|
||||
continue;
|
||||
}
|
||||
callback({
|
||||
{"inputs", to_vector_data(arr.inputs())},
|
||||
{"outputs", to_vector_data(arr.outputs())},
|
||||
{"name", arr.primitive().name()}
|
||||
////// {"state": []});
|
||||
});
|
||||
auto [name, state] = factory.extract_state(arr.primitive_ptr());
|
||||
callback(
|
||||
{{"inputs", to_vector_data(arr.inputs())},
|
||||
{"outputs", to_vector_data(arr.outputs())},
|
||||
{"primitive", name},
|
||||
{"state", state}});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -676,7 +704,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||
|
||||
// Update the table entry
|
||||
fentry.kwarg_keys = std::move(kwarg_keys);
|
||||
fentry.kwarg_keys = kwarg_keys;
|
||||
fentry.inputs = trace_inputs;
|
||||
|
||||
count++;
|
||||
|
||||
17
mlx/export.h
17
mlx/export.h
@@ -4,17 +4,34 @@
|
||||
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <variant>
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using Args = std::vector<array>;
|
||||
using Kwargs = std::unordered_map<std::string, array>;
|
||||
|
||||
// Possible types for a Primitive's state
|
||||
using StateT = std::variant<
|
||||
bool,
|
||||
int,
|
||||
size_t,
|
||||
float,
|
||||
double,
|
||||
Dtype,
|
||||
Shape,
|
||||
Strides,
|
||||
std::vector<int>,
|
||||
std::vector<size_t>,
|
||||
std::string>;
|
||||
|
||||
using ExportCallbackInput = std::unordered_map<
|
||||
std::string,
|
||||
std::variant<
|
||||
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
||||
std::vector<std::pair<std::string, array>>,
|
||||
std::vector<StateT>,
|
||||
std::string>>;
|
||||
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
@@ -133,50 +134,38 @@ auto wrap_export_function(nb::callable fun) {
|
||||
void init_export(nb::module_& m) {
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const nb::callable& callback,
|
||||
[](nb::object& file_or_callback,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
auto wrapped_callback =
|
||||
[callback](const mx::ExportCallbackInput& input) {
|
||||
return callback(input);
|
||||
};
|
||||
mx::export_function(
|
||||
wrapped_callback,
|
||||
wrap_export_function(fun),
|
||||
args_,
|
||||
kwargs_,
|
||||
shapeless);
|
||||
if (nb::isinstance<nb::str>(file_or_callback)) {
|
||||
mx::export_function(
|
||||
nb::cast<std::string>(file_or_callback),
|
||||
wrap_export_function(fun),
|
||||
args_,
|
||||
kwargs_,
|
||||
shapeless);
|
||||
} else {
|
||||
auto callback = nb::cast<nb::callable>(file_or_callback);
|
||||
auto wrapped_callback =
|
||||
[callback](const mx::ExportCallbackInput& input) {
|
||||
return callback(input);
|
||||
};
|
||||
mx::export_function(
|
||||
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
}
|
||||
},
|
||||
"callback"_a,
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a);
|
||||
m.def(
|
||||
"export_function",
|
||||
[](const std::string& file,
|
||||
const nb::callable& fun,
|
||||
const nb::args& args,
|
||||
bool shapeless,
|
||||
const nb::kwargs& kwargs) {
|
||||
auto [args_, kwargs_] =
|
||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||
mx::export_function(
|
||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||
},
|
||||
"file"_a,
|
||||
nb::arg(),
|
||||
"fun"_a,
|
||||
"args"_a,
|
||||
nb::kw_only(),
|
||||
"shapeless"_a = false,
|
||||
"kwargs"_a,
|
||||
R"pbdoc(
|
||||
Export a function to a file.
|
||||
Export an MLX function.
|
||||
|
||||
Example input arrays must be provided to export a function. The example
|
||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||
@@ -189,7 +178,8 @@ void init_export(nb::module_& m) {
|
||||
versions of MLX may not be compatible with future versions.
|
||||
|
||||
Args:
|
||||
file (str): File path to export the function to.
|
||||
file (str or Callable): Either a file path to export the function
|
||||
to or a callback.
|
||||
fun (Callable): A function which takes as input zero or more
|
||||
:class:`array` and returns one or more :class:`array`.
|
||||
*args (array): Example array inputs to the function.
|
||||
|
||||
Reference in New Issue
Block a user