mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
export with callback
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
#primitive, { \
|
#primitive, { \
|
||||||
serialize_primitive<primitive>, \
|
serialize_primitive<primitive>, \
|
||||||
deserialize_primitive<primitive>, \
|
deserialize_primitive<primitive>, \
|
||||||
|
primitive_state<primitive>, \
|
||||||
{__VA_ARGS__} \
|
{__VA_ARGS__} \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
@@ -35,15 +36,20 @@ struct PrimitiveSerializer {
|
|||||||
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
using Serializer = std::function<void(Writer&, const Primitive&)>;
|
||||||
using Deserializer =
|
using Deserializer =
|
||||||
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
|
||||||
|
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
|
||||||
|
|
||||||
PrimitiveSerializer(
|
PrimitiveSerializer(
|
||||||
Serializer serialize,
|
Serializer serialize,
|
||||||
Deserializer deserialize,
|
Deserializer deserialize,
|
||||||
|
StateExtractor extract_state,
|
||||||
std::vector<std::string> keys = {})
|
std::vector<std::string> keys = {})
|
||||||
: serialize(std::move(serialize)),
|
: serialize(std::move(serialize)),
|
||||||
deserialize(std::move(deserialize)),
|
deserialize(std::move(deserialize)),
|
||||||
|
extract_state(std::move(extract_state)),
|
||||||
keys(std::move(keys)) {};
|
keys(std::move(keys)) {};
|
||||||
Serializer serialize;
|
Serializer serialize;
|
||||||
Deserializer deserialize;
|
Deserializer deserialize;
|
||||||
|
StateExtractor extract_state;
|
||||||
std::vector<std::string> keys;
|
std::vector<std::string> keys;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -199,25 +205,26 @@ void serialize_primitive(Writer& os, const Primitive& p) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Possible types for a Primitive's state
|
|
||||||
using StateT = std::variant<int, float, std::vector<int>>;
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void extract_state(T v, std::vector<StateT>& state) {
|
void extract_state(const T state, std::vector<StateT>& unpacked_state) {
|
||||||
if constexpr (std::is_arithmetic_v<T>) {
|
if constexpr (std::is_arithmetic_v<T>) {
|
||||||
state.push_back(v);
|
unpacked_state.push_back(state);
|
||||||
} else if constexpr (std::is_enum_v<T>) {
|
} else if constexpr (std::is_enum_v<T>) {
|
||||||
state.push_back(static_cast<int>(v));
|
unpacked_state.push_back(static_cast<int>(state));
|
||||||
|
} else if constexpr (is_iterable<T>) {
|
||||||
|
unpacked_state.push_back(state);
|
||||||
|
} else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||||
|
std::apply(
|
||||||
|
[&unpacked_state](auto&... x) {
|
||||||
|
(..., extract_state(x, unpacked_state));
|
||||||
|
},
|
||||||
|
state);
|
||||||
}
|
}
|
||||||
// } else if constexpr (is_iterable<T>) {
|
|
||||||
// state.push_back(v);
|
|
||||||
// } else if constexpr (is_pair<T> || is_tuple<T>) {
|
|
||||||
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// std::vector<StateT> extract_state(const Primitive& p) {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<StateT> extract_state(const Primitive& p) {
|
std::vector<StateT> primitive_state(const Primitive& p) {
|
||||||
std::vector<StateT> state;
|
std::vector<StateT> state;
|
||||||
if constexpr (has_state<T>) {
|
if constexpr (has_state<T>) {
|
||||||
extract_state(static_cast<const T&>(p).state(), state);
|
extract_state(static_cast<const T&>(p).state(), state);
|
||||||
@@ -410,6 +417,23 @@ struct PrimitiveFactory {
|
|||||||
"[import_function] Unable to deserialize primitive " + name);
|
"[import_function] Unable to deserialize primitive " + name);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std::pair<std::string, std::vector<StateT>> extract_state(
|
||||||
|
const std::shared_ptr<Primitive>& p) {
|
||||||
|
std::string name = p->name();
|
||||||
|
name = name.substr(0, name.find(' '));
|
||||||
|
if (auto it = name_remap.find(name); it != name_remap.end()) {
|
||||||
|
name = it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto it = factory.find(name); it != factory.end()) {
|
||||||
|
auto state = it->second.extract_state(*p);
|
||||||
|
return {name, state};
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[export_function] Unable to get state for primitive " + name);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
void write_header(Writer& os, int count, bool shapeless) {
|
void write_header(Writer& os, int count, bool shapeless) {
|
||||||
@@ -609,26 +633,30 @@ void FunctionExporter::export_with_callback(
|
|||||||
for (auto& in : inputs) {
|
for (auto& in : inputs) {
|
||||||
input_set.insert(in.id());
|
input_set.insert(in.id());
|
||||||
}
|
}
|
||||||
std::vector<std::pair<std::string, array>> constants;
|
std::vector<std::pair<std::string, array>> new_constants;
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
constants.emplace_back(namer.get_name(arr), arr);
|
if (constants.insert(arr.id()).second) {
|
||||||
|
new_constants.emplace_back(namer.get_name(arr), arr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
callback({{"constants", constants}});
|
callback({{"constants", new_constants}});
|
||||||
}
|
}
|
||||||
|
auto factory = PrimitiveFactory();
|
||||||
|
|
||||||
|
// Callback for each primitive in the tape
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
if (!arr.has_primitive()) {
|
if (!arr.has_primitive()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
callback({
|
auto [name, state] = factory.extract_state(arr.primitive_ptr());
|
||||||
{"inputs", to_vector_data(arr.inputs())},
|
callback(
|
||||||
{"outputs", to_vector_data(arr.outputs())},
|
{{"inputs", to_vector_data(arr.inputs())},
|
||||||
{"name", arr.primitive().name()}
|
{"outputs", to_vector_data(arr.outputs())},
|
||||||
////// {"state": []});
|
{"primitive", name},
|
||||||
});
|
{"state", state}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -676,7 +704,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||||
|
|
||||||
// Update the table entry
|
// Update the table entry
|
||||||
fentry.kwarg_keys = std::move(kwarg_keys);
|
fentry.kwarg_keys = kwarg_keys;
|
||||||
fentry.inputs = trace_inputs;
|
fentry.inputs = trace_inputs;
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
|
|||||||
17
mlx/export.h
17
mlx/export.h
@@ -4,17 +4,34 @@
|
|||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <variant>
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
using Args = std::vector<array>;
|
using Args = std::vector<array>;
|
||||||
using Kwargs = std::unordered_map<std::string, array>;
|
using Kwargs = std::unordered_map<std::string, array>;
|
||||||
|
|
||||||
|
// Possible types for a Primitive's state
|
||||||
|
using StateT = std::variant<
|
||||||
|
bool,
|
||||||
|
int,
|
||||||
|
size_t,
|
||||||
|
float,
|
||||||
|
double,
|
||||||
|
Dtype,
|
||||||
|
Shape,
|
||||||
|
Strides,
|
||||||
|
std::vector<int>,
|
||||||
|
std::vector<size_t>,
|
||||||
|
std::string>;
|
||||||
|
|
||||||
using ExportCallbackInput = std::unordered_map<
|
using ExportCallbackInput = std::unordered_map<
|
||||||
std::string,
|
std::string,
|
||||||
std::variant<
|
std::variant<
|
||||||
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
||||||
std::vector<std::pair<std::string, array>>,
|
std::vector<std::pair<std::string, array>>,
|
||||||
|
std::vector<StateT>,
|
||||||
std::string>>;
|
std::string>>;
|
||||||
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
|
#include <nanobind/stl/pair.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
#include <nanobind/stl/tuple.h>
|
#include <nanobind/stl/tuple.h>
|
||||||
#include <nanobind/stl/unordered_map.h>
|
#include <nanobind/stl/unordered_map.h>
|
||||||
@@ -133,50 +134,38 @@ auto wrap_export_function(nb::callable fun) {
|
|||||||
void init_export(nb::module_& m) {
|
void init_export(nb::module_& m) {
|
||||||
m.def(
|
m.def(
|
||||||
"export_function",
|
"export_function",
|
||||||
[](const nb::callable& callback,
|
[](nb::object& file_or_callback,
|
||||||
const nb::callable& fun,
|
const nb::callable& fun,
|
||||||
const nb::args& args,
|
const nb::args& args,
|
||||||
bool shapeless,
|
bool shapeless,
|
||||||
const nb::kwargs& kwargs) {
|
const nb::kwargs& kwargs) {
|
||||||
auto [args_, kwargs_] =
|
auto [args_, kwargs_] =
|
||||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||||
auto wrapped_callback =
|
if (nb::isinstance<nb::str>(file_or_callback)) {
|
||||||
[callback](const mx::ExportCallbackInput& input) {
|
mx::export_function(
|
||||||
return callback(input);
|
nb::cast<std::string>(file_or_callback),
|
||||||
};
|
wrap_export_function(fun),
|
||||||
mx::export_function(
|
args_,
|
||||||
wrapped_callback,
|
kwargs_,
|
||||||
wrap_export_function(fun),
|
shapeless);
|
||||||
args_,
|
} else {
|
||||||
kwargs_,
|
auto callback = nb::cast<nb::callable>(file_or_callback);
|
||||||
shapeless);
|
auto wrapped_callback =
|
||||||
|
[callback](const mx::ExportCallbackInput& input) {
|
||||||
|
return callback(input);
|
||||||
|
};
|
||||||
|
mx::export_function(
|
||||||
|
callback, wrap_export_function(fun), args_, kwargs_, shapeless);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"callback"_a,
|
nb::arg(),
|
||||||
"fun"_a,
|
|
||||||
"args"_a,
|
|
||||||
nb::kw_only(),
|
|
||||||
"shapeless"_a = false,
|
|
||||||
"kwargs"_a);
|
|
||||||
m.def(
|
|
||||||
"export_function",
|
|
||||||
[](const std::string& file,
|
|
||||||
const nb::callable& fun,
|
|
||||||
const nb::args& args,
|
|
||||||
bool shapeless,
|
|
||||||
const nb::kwargs& kwargs) {
|
|
||||||
auto [args_, kwargs_] =
|
|
||||||
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
|
||||||
mx::export_function(
|
|
||||||
file, wrap_export_function(fun), args_, kwargs_, shapeless);
|
|
||||||
},
|
|
||||||
"file"_a,
|
|
||||||
"fun"_a,
|
"fun"_a,
|
||||||
"args"_a,
|
"args"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"shapeless"_a = false,
|
"shapeless"_a = false,
|
||||||
"kwargs"_a,
|
"kwargs"_a,
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Export a function to a file.
|
Export an MLX function.
|
||||||
|
|
||||||
Example input arrays must be provided to export a function. The example
|
Example input arrays must be provided to export a function. The example
|
||||||
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
|
||||||
@@ -189,7 +178,8 @@ void init_export(nb::module_& m) {
|
|||||||
versions of MLX may not be compatible with future versions.
|
versions of MLX may not be compatible with future versions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): File path to export the function to.
|
file (str or Callable): Either a file path to export the function
|
||||||
|
to or a callback.
|
||||||
fun (Callable): A function which takes as input zero or more
|
fun (Callable): A function which takes as input zero or more
|
||||||
:class:`array` and returns one or more :class:`array`.
|
:class:`array` and returns one or more :class:`array`.
|
||||||
*args (array): Example array inputs to the function.
|
*args (array): Example array inputs to the function.
|
||||||
|
|||||||
Reference in New Issue
Block a user