diff --git a/mlx/export.cpp b/mlx/export.cpp index 94a99ce8e..d7f714c2e 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -14,6 +14,7 @@ #primitive, { \ serialize_primitive, \ deserialize_primitive, \ + primitive_state, \ {__VA_ARGS__} \ } \ } @@ -35,15 +36,20 @@ struct PrimitiveSerializer { using Serializer = std::function; using Deserializer = std::function(Reader&, Stream s)>; + using StateExtractor = std::function(const Primitive&)>; + PrimitiveSerializer( Serializer serialize, Deserializer deserialize, + StateExtractor extract_state, std::vector keys = {}) : serialize(std::move(serialize)), deserialize(std::move(deserialize)), + extract_state(std::move(extract_state)), keys(std::move(keys)) {}; Serializer serialize; Deserializer deserialize; + StateExtractor extract_state; std::vector keys; }; @@ -199,25 +205,26 @@ void serialize_primitive(Writer& os, const Primitive& p) { } } -// Possible types for a Primitive's state -using StateT = std::variant>; - template -void extract_state(T v, std::vector& state) { +void extract_state(const T state, std::vector& unpacked_state) { if constexpr (std::is_arithmetic_v) { - state.push_back(v); + unpacked_state.push_back(state); } else if constexpr (std::is_enum_v) { - state.push_back(static_cast(v)); + unpacked_state.push_back(static_cast(state)); + } else if constexpr (is_iterable) { + unpacked_state.push_back(state); + } else if constexpr (is_pair || is_tuple) { + std::apply( + [&unpacked_state](auto&... x) { + (..., extract_state(x, unpacked_state)); + }, + state); } - // } else if constexpr (is_iterable) { - // state.push_back(v); - // } else if constexpr (is_pair || is_tuple) { - // std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v); - // } } +// std::vector extract_state(const Primitive& p) { template -std::vector extract_state(const Primitive& p) { +std::vector primitive_state(const Primitive& p) { std::vector state; if constexpr (has_state) { extract_state(static_cast(p).state(), state); @@ -410,6 +417,23 @@ struct PrimitiveFactory { "[import_function] Unable to deserialize primitive " + name); } }; + + std::pair> extract_state( + const std::shared_ptr& p) { + std::string name = p->name(); + name = name.substr(0, name.find(' ')); + if (auto it = name_remap.find(name); it != name_remap.end()) { + name = it->second; + } + + if (auto it = factory.find(name); it != factory.end()) { + auto state = it->second.extract_state(*p); + return {name, state}; + } else { + throw std::invalid_argument( + "[export_function] Unable to get state for primitive " + name); + } + }; }; void write_header(Writer& os, int count, bool shapeless) { @@ -609,26 +633,30 @@ void FunctionExporter::export_with_callback( for (auto& in : inputs) { input_set.insert(in.id()); } - std::vector> constants; + std::vector> new_constants; for (auto& arr : tape) { if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) { continue; } - constants.emplace_back(namer.get_name(arr), arr); + if (constants.insert(arr.id()).second) { + new_constants.emplace_back(namer.get_name(arr), arr); + } } - callback({{"constants", constants}}); + callback({{"constants", new_constants}}); } + auto factory = PrimitiveFactory(); + // Callback for each primitive in the tape for (auto& arr : tape) { if (!arr.has_primitive()) { continue; } - callback({ - {"inputs", to_vector_data(arr.inputs())}, - {"outputs", to_vector_data(arr.outputs())}, - {"name", arr.primitive().name()} - ////// {"state": []}); - }); + auto [name, state] = factory.extract_state(arr.primitive_ptr()); + callback( + {{"inputs", to_vector_data(arr.inputs())}, + {"outputs", to_vector_data(arr.outputs())}, + {"primitive", name}, + {"state", state}}); } } @@ -676,7 +704,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3); // Update the table entry - fentry.kwarg_keys = std::move(kwarg_keys); + fentry.kwarg_keys = kwarg_keys; fentry.inputs = trace_inputs; count++; diff --git a/mlx/export.h b/mlx/export.h index c8be4e7e7..934e5471c 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -4,17 +4,34 @@ #include #include +#include #include "mlx/array.h" namespace mlx::core { using Args = std::vector; using Kwargs = std::unordered_map; + +// Possible types for a Primitive's state +using StateT = std::variant< + bool, + int, + size_t, + float, + double, + Dtype, + Shape, + Strides, + std::vector, + std::vector, + std::string>; + using ExportCallbackInput = std::unordered_map< std::string, std::variant< std::vector>, std::vector>, + std::vector, std::string>>; using ExportCallback = std::function; diff --git a/python/src/export.cpp b/python/src/export.cpp index d62317103..3d3ffdf6b 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. #include #include +#include #include #include #include @@ -133,50 +134,38 @@ auto wrap_export_function(nb::callable fun) { void init_export(nb::module_& m) { m.def( "export_function", - [](const nb::callable& callback, + [](nb::object& file_or_callback, const nb::callable& fun, const nb::args& args, bool shapeless, const nb::kwargs& kwargs) { auto [args_, kwargs_] = validate_and_extract_inputs(args, kwargs, "[export_function]"); - auto wrapped_callback = - [callback](const mx::ExportCallbackInput& input) { - return callback(input); - }; - mx::export_function( - wrapped_callback, - wrap_export_function(fun), - args_, - kwargs_, - shapeless); + if (nb::isinstance(file_or_callback)) { + mx::export_function( + nb::cast(file_or_callback), + wrap_export_function(fun), + args_, + kwargs_, + shapeless); + } else { + auto callback = nb::cast(file_or_callback); + auto wrapped_callback = + [callback](const mx::ExportCallbackInput& input) { + return callback(input); + }; + mx::export_function( + callback, wrap_export_function(fun), args_, kwargs_, shapeless); + } }, - "callback"_a, - "fun"_a, - "args"_a, - nb::kw_only(), - "shapeless"_a = false, - "kwargs"_a); - m.def( - "export_function", - [](const std::string& file, - const nb::callable& fun, - const nb::args& args, - bool shapeless, - const nb::kwargs& kwargs) { - auto [args_, kwargs_] = - validate_and_extract_inputs(args, kwargs, "[export_function]"); - mx::export_function( - file, wrap_export_function(fun), args_, kwargs_, shapeless); - }, - "file"_a, + nb::arg(), "fun"_a, "args"_a, nb::kw_only(), "shapeless"_a = false, "kwargs"_a, R"pbdoc( - Export a function to a file. + Export an MLX function. Example input arrays must be provided to export a function. The example inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays @@ -189,7 +178,8 @@ void init_export(nb::module_& m) { versions of MLX may not be compatible with future versions. Args: - file (str): File path to export the function to. + file (str or Callable): Either a file path to export the function + to or a callback. fun (Callable): A function which takes as input zero or more :class:`array` and returns one or more :class:`array`. *args (array): Example array inputs to the function.