diff --git a/mlx/export.cpp b/mlx/export.cpp index 19944dfc4..94a99ce8e 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -3,6 +3,7 @@ #include #include "mlx/compile_impl.h" #include "mlx/fast_primitives.h" +#include "mlx/graph_utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" #include "mlx/version.h" @@ -198,6 +199,32 @@ void serialize_primitive(Writer& os, const Primitive& p) { } } +// Possible types for a Primitive's state +using StateT = std::variant>; + +template +void extract_state(T v, std::vector& state) { + if constexpr (std::is_arithmetic_v) { + state.push_back(v); + } else if constexpr (std::is_enum_v) { + state.push_back(static_cast(v)); + } + // } else if constexpr (is_iterable) { + // state.push_back(v); + // } else if constexpr (is_pair || is_tuple) { + // std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v); + // } +} + +template +std::vector extract_state(const Primitive& p) { + std::vector state; + if constexpr (has_state) { + extract_state(static_cast(p).state(), state); + } + return state; +} + template std::shared_ptr deserialize_primitive(Reader& is, Stream s) { if constexpr (has_state) { @@ -545,10 +572,66 @@ FunctionExporter::FunctionExporter( write_header(os, count, shapeless); } +FunctionExporter::FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless) + : callback(callback), + fun(std::move(fun)), + ftable(std::make_shared(shapeless)) {} + void FunctionExporter::close() { closed = true; }; +void FunctionExporter::export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape) { + NodeNamer namer{}; + auto to_vector_data = [&namer](const auto& arrays) { + std::vector> data; + for (auto& a : arrays) { + data.emplace_back(namer.get_name(a), a.shape(), a.dtype()); + } + return data; + }; + + // Callback on the inputs + callback({{"inputs", to_vector_data(inputs)}}); + + // Callback on the outputs + callback({{"outputs", to_vector_data(outputs)}}); + + // Callback on the constants + { + std::unordered_set input_set; + for (auto& in : inputs) { + input_set.insert(in.id()); + } + std::vector> constants; + for (auto& arr : tape) { + if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) { + continue; + } + constants.emplace_back(namer.get_name(arr), arr); + } + callback({{"constants", constants}}); + } + + for (auto& arr : tape) { + if (!arr.has_primitive()) { + continue; + } + callback({ + {"inputs", to_vector_data(arr.inputs())}, + {"outputs", to_vector_data(arr.outputs())}, + {"name", arr.primitive().name()} + ////// {"state": []}); + }); + } +} + void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { if (closed) { throw std::runtime_error( @@ -592,10 +675,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3); - // Update header + // Update the table entry + fentry.kwarg_keys = std::move(kwarg_keys); + fentry.inputs = trace_inputs; + count++; - // Overwrite the header + if (callback) { + export_with_callback(trace_inputs, trace_outputs, tape); + return; + } + + // Update the header auto pos = os.tell(); os.seek(0); write_header(os, count, ftable->shapeless); @@ -616,10 +707,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { serialize(os, trace_inputs); serialize(os, arrays_to_ids(trace_outputs)); - // Update the table entry - fentry.kwarg_keys = std::move(kwarg_keys); - fentry.inputs = std::move(trace_inputs); - std::unordered_set input_set( trace_input_ids.begin(), trace_input_ids.end()); @@ -730,6 +817,58 @@ void export_function( exporter(file, fun, shapeless)(args, kwargs); } +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless /* = false */) { + return FunctionExporter{ + callback, + [fun](const Args& args, const Kwargs&) { return fun(args); }, + shapeless}; +} + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless /* = false */) { + return exporter( + callback, + [fun](const Args&, const Kwargs kwargs) { return fun(kwargs); }, + shapeless); +} + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless /* = false */) { + return FunctionExporter{callback, fun, shapeless}; +} + +void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(args); +} + +void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(kwargs); +} + +void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless /* = false */) { + exporter(callback, fun, shapeless)(args, kwargs); +} + std::vector ImportedFunction::operator()(const Kwargs& kwargs) const { return this->operator()({}, kwargs); } diff --git a/mlx/export.h b/mlx/export.h index c6859c6d8..c8be4e7e7 100644 --- a/mlx/export.h +++ b/mlx/export.h @@ -10,6 +10,13 @@ namespace mlx::core { using Args = std::vector; using Kwargs = std::unordered_map; +using ExportCallbackInput = std::unordered_map< + std::string, + std::variant< + std::vector>, + std::vector>, + std::string>>; +using ExportCallback = std::function; struct FunctionExporter; @@ -61,6 +68,47 @@ struct ImportedFunction; */ ImportedFunction import_function(const std::string& file); +/** + * Make an exporter to export multiple traces of a given function with the same + * callback. + */ +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function with a callback. + */ +void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + } // namespace mlx::core #include "mlx/export_impl.h" diff --git a/mlx/export_impl.h b/mlx/export_impl.h index 74a1c35c1..82756c23f 100644 --- a/mlx/export_impl.h +++ b/mlx/export_impl.h @@ -38,13 +38,39 @@ struct FunctionExporter { const std::function(const Args&, const Kwargs&)>&, bool shapeless); + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + FunctionExporter( const std::string& file, std::function(const Args&, const Kwargs&)> fun, bool shapeless); + + FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + io::FileWriter os; + ExportCallback callback; std::function(const Args&, const Kwargs& kwargs)> fun; void export_function(const Args& args, const Kwargs& kwargs); + void export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape); std::set constants; int count{0}; bool closed{false}; diff --git a/mlx/io/load.h b/mlx/io/load.h index 8b5dd95b6..0efcb367b 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -108,6 +108,7 @@ class ParallelFileReader : public Reader { class FileWriter : public Writer { public: + explicit FileWriter() {} explicit FileWriter(std::string file_path) : fd_(open( file_path.c_str(), diff --git a/python/src/export.cpp b/python/src/export.cpp index 4428e7cc8..d62317103 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -2,7 +2,9 @@ #include #include #include +#include #include +#include #include #include @@ -129,6 +131,32 @@ auto wrap_export_function(nb::callable fun) { } void init_export(nb::module_& m) { + m.def( + "export_function", + [](const nb::callable& callback, + const nb::callable& fun, + const nb::args& args, + bool shapeless, + const nb::kwargs& kwargs) { + auto [args_, kwargs_] = + validate_and_extract_inputs(args, kwargs, "[export_function]"); + auto wrapped_callback = + [callback](const mx::ExportCallbackInput& input) { + return callback(input); + }; + mx::export_function( + wrapped_callback, + wrap_export_function(fun), + args_, + kwargs_, + shapeless); + }, + "callback"_a, + "fun"_a, + "args"_a, + nb::kw_only(), + "shapeless"_a = false, + "kwargs"_a); m.def( "export_function", [](const std::string& file,