mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
export with callback
This commit is contained in:
151
mlx/export.cpp
151
mlx/export.cpp
@@ -3,6 +3,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
#include "mlx/fast_primitives.h"
|
#include "mlx/fast_primitives.h"
|
||||||
|
#include "mlx/graph_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "mlx/version.h"
|
#include "mlx/version.h"
|
||||||
@@ -198,6 +199,32 @@ void serialize_primitive(Writer& os, const Primitive& p) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Possible types for a Primitive's state
|
||||||
|
using StateT = std::variant<int, float, std::vector<int>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void extract_state(T v, std::vector<StateT>& state) {
|
||||||
|
if constexpr (std::is_arithmetic_v<T>) {
|
||||||
|
state.push_back(v);
|
||||||
|
} else if constexpr (std::is_enum_v<T>) {
|
||||||
|
state.push_back(static_cast<int>(v));
|
||||||
|
}
|
||||||
|
// } else if constexpr (is_iterable<T>) {
|
||||||
|
// state.push_back(v);
|
||||||
|
// } else if constexpr (is_pair<T> || is_tuple<T>) {
|
||||||
|
// std::apply([&os](auto&... x) { (..., extract_state(os, state)); }, v);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<StateT> extract_state(const Primitive& p) {
|
||||||
|
std::vector<StateT> state;
|
||||||
|
if constexpr (has_state<T>) {
|
||||||
|
extract_state(static_cast<const T&>(p).state(), state);
|
||||||
|
}
|
||||||
|
return state;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
|
||||||
if constexpr (has_state<T>) {
|
if constexpr (has_state<T>) {
|
||||||
@@ -545,10 +572,66 @@ FunctionExporter::FunctionExporter(
|
|||||||
write_header(os, count, shapeless);
|
write_header(os, count, shapeless);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionExporter::FunctionExporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
|
bool shapeless)
|
||||||
|
: callback(callback),
|
||||||
|
fun(std::move(fun)),
|
||||||
|
ftable(std::make_shared<FunctionTable>(shapeless)) {}
|
||||||
|
|
||||||
void FunctionExporter::close() {
|
void FunctionExporter::close() {
|
||||||
closed = true;
|
closed = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void FunctionExporter::export_with_callback(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& tape) {
|
||||||
|
NodeNamer namer{};
|
||||||
|
auto to_vector_data = [&namer](const auto& arrays) {
|
||||||
|
std::vector<std::tuple<std::string, Shape, Dtype>> data;
|
||||||
|
for (auto& a : arrays) {
|
||||||
|
data.emplace_back(namer.get_name(a), a.shape(), a.dtype());
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Callback on the inputs
|
||||||
|
callback({{"inputs", to_vector_data(inputs)}});
|
||||||
|
|
||||||
|
// Callback on the outputs
|
||||||
|
callback({{"outputs", to_vector_data(outputs)}});
|
||||||
|
|
||||||
|
// Callback on the constants
|
||||||
|
{
|
||||||
|
std::unordered_set<std::uintptr_t> input_set;
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
input_set.insert(in.id());
|
||||||
|
}
|
||||||
|
std::vector<std::pair<std::string, array>> constants;
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
constants.emplace_back(namer.get_name(arr), arr);
|
||||||
|
}
|
||||||
|
callback({{"constants", constants}});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
if (!arr.has_primitive()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
callback({
|
||||||
|
{"inputs", to_vector_data(arr.inputs())},
|
||||||
|
{"outputs", to_vector_data(arr.outputs())},
|
||||||
|
{"name", arr.primitive().name()}
|
||||||
|
////// {"state": []});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||||
if (closed) {
|
if (closed) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -592,10 +675,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
|
|
||||||
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
|
||||||
|
|
||||||
// Update header
|
// Update the table entry
|
||||||
|
fentry.kwarg_keys = std::move(kwarg_keys);
|
||||||
|
fentry.inputs = trace_inputs;
|
||||||
|
|
||||||
count++;
|
count++;
|
||||||
|
|
||||||
// Overwrite the header
|
if (callback) {
|
||||||
|
export_with_callback(trace_inputs, trace_outputs, tape);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the header
|
||||||
auto pos = os.tell();
|
auto pos = os.tell();
|
||||||
os.seek(0);
|
os.seek(0);
|
||||||
write_header(os, count, ftable->shapeless);
|
write_header(os, count, ftable->shapeless);
|
||||||
@@ -616,10 +707,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
serialize(os, trace_inputs);
|
serialize(os, trace_inputs);
|
||||||
serialize(os, arrays_to_ids(trace_outputs));
|
serialize(os, arrays_to_ids(trace_outputs));
|
||||||
|
|
||||||
// Update the table entry
|
|
||||||
fentry.kwarg_keys = std::move(kwarg_keys);
|
|
||||||
fentry.inputs = std::move(trace_inputs);
|
|
||||||
|
|
||||||
std::unordered_set<std::uintptr_t> input_set(
|
std::unordered_set<std::uintptr_t> input_set(
|
||||||
trace_input_ids.begin(), trace_input_ids.end());
|
trace_input_ids.begin(), trace_input_ids.end());
|
||||||
|
|
||||||
@@ -730,6 +817,58 @@ void export_function(
|
|||||||
exporter(file, fun, shapeless)(args, kwargs);
|
exporter(file, fun, shapeless)(args, kwargs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return FunctionExporter{
|
||||||
|
callback,
|
||||||
|
[fun](const Args& args, const Kwargs&) { return fun(args); },
|
||||||
|
shapeless};
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return exporter(
|
||||||
|
callback,
|
||||||
|
[fun](const Args&, const Kwargs kwargs) { return fun(kwargs); },
|
||||||
|
shapeless);
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
return FunctionExporter{callback, fun, shapeless};
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
exporter(callback, fun, shapeless)(args, kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
|
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
|
||||||
return this->operator()({}, kwargs);
|
return this->operator()({}, kwargs);
|
||||||
}
|
}
|
||||||
|
|||||||
48
mlx/export.h
48
mlx/export.h
@@ -10,6 +10,13 @@ namespace mlx::core {
|
|||||||
|
|
||||||
using Args = std::vector<array>;
|
using Args = std::vector<array>;
|
||||||
using Kwargs = std::unordered_map<std::string, array>;
|
using Kwargs = std::unordered_map<std::string, array>;
|
||||||
|
using ExportCallbackInput = std::unordered_map<
|
||||||
|
std::string,
|
||||||
|
std::variant<
|
||||||
|
std::vector<std::tuple<std::string, Shape, Dtype>>,
|
||||||
|
std::vector<std::pair<std::string, array>>,
|
||||||
|
std::string>>;
|
||||||
|
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
|
||||||
|
|
||||||
struct FunctionExporter;
|
struct FunctionExporter;
|
||||||
|
|
||||||
@@ -61,6 +68,47 @@ struct ImportedFunction;
|
|||||||
*/
|
*/
|
||||||
ImportedFunction import_function(const std::string& file);
|
ImportedFunction import_function(const std::string& file);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Make an exporter to export multiple traces of a given function with the same
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
FunctionExporter exporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Export a function with a callback.
|
||||||
|
*/
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>& fun,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
void export_function(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
|
||||||
|
const Args& args,
|
||||||
|
const Kwargs& kwargs,
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|
||||||
#include "mlx/export_impl.h"
|
#include "mlx/export_impl.h"
|
||||||
|
|||||||
@@ -38,13 +38,39 @@ struct FunctionExporter {
|
|||||||
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||||
bool shapeless);
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Args&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
|
friend FunctionExporter exporter(
|
||||||
|
const ExportCallback&,
|
||||||
|
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
FunctionExporter(
|
FunctionExporter(
|
||||||
const std::string& file,
|
const std::string& file,
|
||||||
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
bool shapeless);
|
bool shapeless);
|
||||||
|
|
||||||
|
FunctionExporter(
|
||||||
|
const ExportCallback& callback,
|
||||||
|
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
|
||||||
|
bool shapeless);
|
||||||
|
|
||||||
io::FileWriter os;
|
io::FileWriter os;
|
||||||
|
ExportCallback callback;
|
||||||
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
|
||||||
void export_function(const Args& args, const Kwargs& kwargs);
|
void export_function(const Args& args, const Kwargs& kwargs);
|
||||||
|
void export_with_callback(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
const std::vector<array>& tape);
|
||||||
std::set<std::uintptr_t> constants;
|
std::set<std::uintptr_t> constants;
|
||||||
int count{0};
|
int count{0};
|
||||||
bool closed{false};
|
bool closed{false};
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ class ParallelFileReader : public Reader {
|
|||||||
|
|
||||||
class FileWriter : public Writer {
|
class FileWriter : public Writer {
|
||||||
public:
|
public:
|
||||||
|
explicit FileWriter() {}
|
||||||
explicit FileWriter(std::string file_path)
|
explicit FileWriter(std::string file_path)
|
||||||
: fd_(open(
|
: fd_(open(
|
||||||
file_path.c_str(),
|
file_path.c_str(),
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
#include <nanobind/nanobind.h>
|
#include <nanobind/nanobind.h>
|
||||||
#include <nanobind/stl/optional.h>
|
#include <nanobind/stl/optional.h>
|
||||||
#include <nanobind/stl/string.h>
|
#include <nanobind/stl/string.h>
|
||||||
|
#include <nanobind/stl/tuple.h>
|
||||||
#include <nanobind/stl/unordered_map.h>
|
#include <nanobind/stl/unordered_map.h>
|
||||||
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -129,6 +131,32 @@ auto wrap_export_function(nb::callable fun) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void init_export(nb::module_& m) {
|
void init_export(nb::module_& m) {
|
||||||
|
m.def(
|
||||||
|
"export_function",
|
||||||
|
[](const nb::callable& callback,
|
||||||
|
const nb::callable& fun,
|
||||||
|
const nb::args& args,
|
||||||
|
bool shapeless,
|
||||||
|
const nb::kwargs& kwargs) {
|
||||||
|
auto [args_, kwargs_] =
|
||||||
|
validate_and_extract_inputs(args, kwargs, "[export_function]");
|
||||||
|
auto wrapped_callback =
|
||||||
|
[callback](const mx::ExportCallbackInput& input) {
|
||||||
|
return callback(input);
|
||||||
|
};
|
||||||
|
mx::export_function(
|
||||||
|
wrapped_callback,
|
||||||
|
wrap_export_function(fun),
|
||||||
|
args_,
|
||||||
|
kwargs_,
|
||||||
|
shapeless);
|
||||||
|
},
|
||||||
|
"callback"_a,
|
||||||
|
"fun"_a,
|
||||||
|
"args"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"shapeless"_a = false,
|
||||||
|
"kwargs"_a);
|
||||||
m.def(
|
m.def(
|
||||||
"export_function",
|
"export_function",
|
||||||
[](const std::string& file,
|
[](const std::string& file,
|
||||||
|
|||||||
Reference in New Issue
Block a user