export with callback

This commit is contained in:
Awni Hannun
2025-08-04 11:48:36 -07:00
parent aa9d44b3d4
commit a95d4a74d9
5 changed files with 248 additions and 6 deletions

View File

@@ -3,6 +3,7 @@
#include <map> #include <map>
#include "mlx/compile_impl.h" #include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h" #include "mlx/fast_primitives.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include "mlx/version.h" #include "mlx/version.h"
@@ -198,6 +199,32 @@ void serialize_primitive(Writer& os, const Primitive& p) {
} }
} }
// Possible types for a Primitive's state
using StateT = std::variant<int, float, std::vector<int>>;
template <typename T>
void extract_state(T v, std::vector<StateT>& state) {
if constexpr (std::is_arithmetic_v<T>) {
state.push_back(v);
} else if constexpr (std::is_enum_v<T>) {
state.push_back(static_cast<int>(v));
}
// } else if constexpr (is_iterable<T>) {
// state.push_back(v);
// } else if constexpr (is_pair<T> || is_tuple<T>) {
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
// }
}
template <typename T>
std::vector<StateT> extract_state(const Primitive& p) {
std::vector<StateT> state;
if constexpr (has_state<T>) {
extract_state(static_cast<const T&>(p).state(), state);
}
return state;
}
template <typename T> template <typename T>
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) { std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
if constexpr (has_state<T>) { if constexpr (has_state<T>) {
@@ -545,10 +572,66 @@ FunctionExporter::FunctionExporter(
write_header(os, count, shapeless); write_header(os, count, shapeless);
} }
FunctionExporter::FunctionExporter(
const ExportCallback& callback,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless)
: callback(callback),
fun(std::move(fun)),
ftable(std::make_shared<FunctionTable>(shapeless)) {}
void FunctionExporter::close() { void FunctionExporter::close() {
closed = true; closed = true;
}; };
void FunctionExporter::export_with_callback(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape) {
NodeNamer namer{};
auto to_vector_data = [&namer](const auto& arrays) {
std::vector<std::tuple<std::string, Shape, Dtype>> data;
for (auto& a : arrays) {
data.emplace_back(namer.get_name(a), a.shape(), a.dtype());
}
return data;
};
// Callback on the inputs
callback({{"inputs", to_vector_data(inputs)}});
// Callback on the outputs
callback({{"outputs", to_vector_data(outputs)}});
// Callback on the constants
{
std::unordered_set<std::uintptr_t> input_set;
for (auto& in : inputs) {
input_set.insert(in.id());
}
std::vector<std::pair<std::string, array>> constants;
for (auto& arr : tape) {
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
continue;
}
constants.emplace_back(namer.get_name(arr), arr);
}
callback({{"constants", constants}});
}
for (auto& arr : tape) {
if (!arr.has_primitive()) {
continue;
}
callback({
{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())},
{"name", arr.primitive().name()}
////// {"state": []});
});
}
}
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
if (closed) { if (closed) {
throw std::runtime_error( throw std::runtime_error(
@@ -592,10 +675,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3); detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
// Update header // Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys);
fentry.inputs = trace_inputs;
count++; count++;
// Overwrite the header if (callback) {
export_with_callback(trace_inputs, trace_outputs, tape);
return;
}
// Update the header
auto pos = os.tell(); auto pos = os.tell();
os.seek(0); os.seek(0);
write_header(os, count, ftable->shapeless); write_header(os, count, ftable->shapeless);
@@ -616,10 +707,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
serialize(os, trace_inputs); serialize(os, trace_inputs);
serialize(os, arrays_to_ids(trace_outputs)); serialize(os, arrays_to_ids(trace_outputs));
// Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys);
fentry.inputs = std::move(trace_inputs);
std::unordered_set<std::uintptr_t> input_set( std::unordered_set<std::uintptr_t> input_set(
trace_input_ids.begin(), trace_input_ids.end()); trace_input_ids.begin(), trace_input_ids.end());
@@ -730,6 +817,58 @@ void export_function(
exporter(file, fun, shapeless)(args, kwargs); exporter(file, fun, shapeless)(args, kwargs);
} }
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{
callback,
[fun](const Args& args, const Kwargs&) { return fun(args); },
shapeless};
}
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless /* = false */) {
return exporter(
callback,
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
shapeless);
}
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{callback, fun, shapeless};
}
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless /* = false */) {
exporter(callback, fun, shapeless)(args);
}
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(callback, fun, shapeless)(kwargs);
}
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless /* = false */) {
exporter(callback, fun, shapeless)(args, kwargs);
}
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const { std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
return this->operator()({}, kwargs); return this->operator()({}, kwargs);
} }

View File

@@ -10,6 +10,13 @@ namespace mlx::core {
using Args = std::vector<array>; using Args = std::vector<array>;
using Kwargs = std::unordered_map<std::string, array>; using Kwargs = std::unordered_map<std::string, array>;
using ExportCallbackInput = std::unordered_map<
std::string,
std::variant<
std::vector<std::tuple<std::string, Shape, Dtype>>,
std::vector<std::pair<std::string, array>>,
std::string>>;
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
struct FunctionExporter; struct FunctionExporter;
@@ -61,6 +68,47 @@ struct ImportedFunction;
*/ */
ImportedFunction import_function(const std::string& file); ImportedFunction import_function(const std::string& file);
/**
* Make an exporter to export multiple traces of a given function with the same
* callback.
*/
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless = false);
/**
* Export a function with a callback.
*/
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless = false);
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless = false);
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless = false);
} // namespace mlx::core } // namespace mlx::core
#include "mlx/export_impl.h" #include "mlx/export_impl.h"

View File

@@ -38,13 +38,39 @@ struct FunctionExporter {
const std::function<std::vector<array>(const Args&, const Kwargs&)>&, const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless); bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Args&)>&,
bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Kwargs&)>&,
bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless);
FunctionExporter( FunctionExporter(
const std::string& file, const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun, std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless); bool shapeless);
FunctionExporter(
const ExportCallback& callback,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless);
io::FileWriter os; io::FileWriter os;
ExportCallback callback;
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun; std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
void export_function(const Args& args, const Kwargs& kwargs); void export_function(const Args& args, const Kwargs& kwargs);
void export_with_callback(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape);
std::set<std::uintptr_t> constants; std::set<std::uintptr_t> constants;
int count{0}; int count{0};
bool closed{false}; bool closed{false};

View File

@@ -108,6 +108,7 @@ class ParallelFileReader : public Reader {
class FileWriter : public Writer { class FileWriter : public Writer {
public: public:
explicit FileWriter() {}
explicit FileWriter(std::string file_path) explicit FileWriter(std::string file_path)
: fd_(open( : fd_(open(
file_path.c_str(), file_path.c_str(),

View File

@@ -2,7 +2,9 @@
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h> #include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unordered_map.h> #include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h> #include <nanobind/stl/vector.h>
#include <fstream> #include <fstream>
@@ -129,6 +131,32 @@ auto wrap_export_function(nb::callable fun) {
} }
void init_export(nb::module_& m) { void init_export(nb::module_& m) {
m.def(
"export_function",
[](const nb::callable& callback,
const nb::callable& fun,
const nb::args& args,
bool shapeless,
const nb::kwargs& kwargs) {
auto [args_, kwargs_] =
validate_and_extract_inputs(args, kwargs, "[export_function]");
auto wrapped_callback =
[callback](const mx::ExportCallbackInput& input) {
return callback(input);
};
mx::export_function(
wrapped_callback,
wrap_export_function(fun),
args_,
kwargs_,
shapeless);
},
"callback"_a,
"fun"_a,
"args"_a,
nb::kw_only(),
"shapeless"_a = false,
"kwargs"_a);
m.def( m.def(
"export_function", "export_function",
[](const std::string& file, [](const std::string& file,