Export with callback (#2612)

* export with callback

* export with callback

* Add types, fix kwarg ordering bug + test

* cleanup, test, fix

* typos
This commit is contained in:
Awni Hannun
2025-10-08 19:24:33 -07:00
committed by GitHub
parent 85a8824a8c
commit e89e8b4272
10 changed files with 370 additions and 33 deletions

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file must not include any host-only code, utilies that work under both
// This file must not include any host-only code, utilities that work under both
// host and device can be put here.
//
// See more about the requirements at:
@@ -202,7 +202,7 @@ struct Limits<
}
};
// CUDA 11 does not have host side arithmatic operators for half types.
// CUDA 11 does not have host side arithmetic operators for half types.
template <typename T>
struct Limits<
T,

View File

@@ -1,8 +1,8 @@
// Copyright © 2025 Apple Inc.
// This file includes host-only utilies for writing CUDA kernels, the difference
// from backend/cuda/device/utils.cuh is that the latter file only include
// device-only code.
// This file includes host-only utilities for writing CUDA kernels. The
// difference from backend/cuda/device/utils.cuh is that the latter file only
// includes device-only code.
#pragma once

View File

@@ -1,6 +1,6 @@
// Copyright © 2025 Apple Inc.
// This file include utilies that are used by C++ code (i.e. .cpp files).
// This file includes utilities that are used by C++ code (i.e. .cpp files).
#pragma once

View File

@@ -3,6 +3,7 @@
#include <map>
#include "mlx/compile_impl.h"
#include "mlx/fast_primitives.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include "mlx/version.h"
@@ -13,6 +14,7 @@
#primitive, { \
serialize_primitive<primitive>, \
deserialize_primitive<primitive>, \
primitive_state<primitive>, \
{__VA_ARGS__} \
} \
}
@@ -34,15 +36,20 @@ struct PrimitiveSerializer {
using Serializer = std::function<void(Writer&, const Primitive&)>;
using Deserializer =
std::function<std::shared_ptr<Primitive>(Reader&, Stream s)>;
using StateExtractor = std::function<std::vector<StateT>(const Primitive&)>;
PrimitiveSerializer(
Serializer serialize,
Deserializer deserialize,
StateExtractor extract_state,
std::vector<std::string> keys = {})
: serialize(std::move(serialize)),
deserialize(std::move(deserialize)),
extract_state(std::move(extract_state)),
keys(std::move(keys)) {};
Serializer serialize;
Deserializer deserialize;
StateExtractor extract_state;
std::vector<std::string> keys;
};
@@ -198,6 +205,32 @@ void serialize_primitive(Writer& os, const Primitive& p) {
}
}
// Flatten one piece of primitive state into `unpacked_state`.
//
//  - enum values are appended as their `int` value,
//  - arithmetic and iterable values are appended unchanged,
//  - pairs and tuples are expanded element by element (recursively),
//  - a value of any other type is skipped without an error.
template <typename T>
void extract_state(const T state, std::vector<StateT>& unpacked_state) {
  if constexpr (std::is_enum_v<T>) {
    unpacked_state.push_back(static_cast<int>(state));
  } else if constexpr (std::is_arithmetic_v<T> || is_iterable<T>) {
    unpacked_state.push_back(state);
  } else if constexpr (is_pair<T> || is_tuple<T>) {
    // Visit the members in declaration order and flatten each one in turn.
    auto unpack_each = [&unpacked_state](auto&... elems) {
      (..., extract_state(elems, unpacked_state));
    };
    std::apply(unpack_each, state);
  }
}
// Return the state of `p`, viewed as the concrete primitive type `T`, as a
// flat list of values. When `has_state<T>` is false the list is empty.
template <typename T>
std::vector<StateT> primitive_state(const Primitive& p) {
  std::vector<StateT> flattened;
  if constexpr (has_state<T>) {
    const auto& concrete = static_cast<const T&>(p);
    extract_state(concrete.state(), flattened);
  }
  return flattened;
}
template <typename T>
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
if constexpr (has_state<T>) {
@@ -383,6 +416,23 @@ struct PrimitiveFactory {
"[import_function] Unable to deserialize primitive " + name);
}
};
// Return the registry name of primitive `p` together with its flattened
// state, using the state extractor registered under that name.
//
// Throws std::invalid_argument when no entry is registered for the name.
std::pair<std::string, std::vector<StateT>> extract_state(
    const std::shared_ptr<Primitive>& p) {
  // Only the text before the first space of the reported name is used.
  const std::string reported = p->name();
  std::string name = reported.substr(0, reported.find(' '));

  // Some names are registered under a different key.
  auto remapped = name_remap.find(name);
  if (remapped != name_remap.end()) {
    name = remapped->second;
  }

  auto entry = factory.find(name);
  if (entry == factory.end()) {
    throw std::invalid_argument(
        "[export_function] Unable to get state for primitive " + name);
  }
  return {name, entry->second.extract_state(*p)};
};
};
void write_header(Writer& os, int count, bool shapeless) {
@@ -416,8 +466,10 @@ struct FunctionTable {
};
bool shapeless;
std::unordered_map<int, std::vector<Function>> table;
Function* find(const Args& args, const Kwargs& kwargs);
std::pair<Function&, bool> emplace(const Args& args, const Kwargs& kwargs);
Function* find(const Args& args, const std::map<std::string, array>& kwargs);
std::pair<Function&, bool> emplace(
const Args& args,
const std::map<std::string, array>& kwargs);
void insert(
std::vector<std::string> kwarg_keys,
std::vector<array> inputs,
@@ -453,12 +505,15 @@ struct FunctionTable {
}
private:
bool match(const Args& args, const Kwargs& kwargs, const Function& fun);
bool match(
const Args& args,
const std::map<std::string, array>& kwargs,
const Function& fun);
};
bool FunctionTable::match(
const Args& args,
const Kwargs& kwargs,
const std::map<std::string, array>& kwargs,
const Function& fun) {
for (auto& k : fun.kwarg_keys) {
if (kwargs.find(k) == kwargs.end()) {
@@ -486,9 +541,7 @@ bool FunctionTable::match(
return false;
}
}
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [_, in] : sorted_kwargs) {
for (auto& [_, in] : kwargs) {
if (!match_inputs(in, fun.inputs[i++])) {
return false;
}
@@ -499,7 +552,7 @@ bool FunctionTable::match(
std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
const Args& args,
const Kwargs& kwargs) {
const std::map<std::string, array>& kwargs) {
auto n_inputs = args.size() + kwargs.size();
auto [it, _] = table.emplace(n_inputs, std::vector<Function>{});
auto& funs_vec = it->second;
@@ -516,7 +569,7 @@ std::pair<FunctionTable::Function&, bool> FunctionTable::emplace(
FunctionTable::Function* FunctionTable::find(
const Args& args,
const Kwargs& kwargs) {
const std::map<std::string, array>& kwargs) {
auto n_inputs = args.size() + kwargs.size();
auto it = table.find(n_inputs);
if (it == table.end()) {
@@ -545,16 +598,86 @@ FunctionExporter::FunctionExporter(
write_header(os, count, shapeless);
}
FunctionExporter::FunctionExporter(
const ExportCallback& callback,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless)
: callback(callback),
fun(std::move(fun)),
ftable(std::make_shared<FunctionTable>(shapeless)) {}
void FunctionExporter::close() {
closed = true;
};
void FunctionExporter::export_with_callback(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::vector<std::string>& kwarg_keys) {
NodeNamer namer{};
auto to_vector_data = [&namer](const auto& arrays) {
std::vector<std::tuple<std::string, Shape, Dtype>> data;
for (auto& a : arrays) {
data.emplace_back(namer.get_name(a), a.shape(), a.dtype());
}
return data;
};
// Callback on the inputs
callback({{"type", "inputs"}, {"inputs", to_vector_data(inputs)}});
std::vector<std::pair<std::string, std::string>> keyword_inputs;
for (int i = inputs.size() - kwarg_keys.size(), j = 0; i < inputs.size();
++i, ++j) {
keyword_inputs.emplace_back(kwarg_keys[j], namer.get_name(inputs[i]));
}
callback({{"type", "keyword_inputs"}, {"keywords", keyword_inputs}});
// Callback on the outputs
callback({{"type", "outputs"}, {"outputs", to_vector_data(outputs)}});
// Callback on the constants
{
std::unordered_set<std::uintptr_t> input_set;
for (auto& in : inputs) {
input_set.insert(in.id());
}
std::vector<std::pair<std::string, array>> new_constants;
for (auto& arr : tape) {
if (arr.has_primitive() || input_set.find(arr.id()) != input_set.end()) {
continue;
}
if (constants.insert(arr.id()).second) {
new_constants.emplace_back(namer.get_name(arr), arr);
}
}
callback({{"type", "constants"}, {"constants", new_constants}});
}
auto factory = PrimitiveFactory();
// Callback for each primitive in the tape
for (auto& arr : tape) {
if (!arr.has_primitive()) {
continue;
}
auto [name, state] = factory.extract_state(arr.primitive_ptr());
callback(
{{"type", "primitive"},
{"inputs", to_vector_data(arr.inputs())},
{"outputs", to_vector_data(arr.outputs())},
{"name", name},
{"arguments", state}});
}
}
void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
if (closed) {
throw std::runtime_error(
"[export_function] Attempting to write after exporting is closed.");
}
auto [fentry, inserted] = ftable->emplace(args, kwargs);
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
auto [fentry, inserted] = ftable->emplace(args, sorted_kwargs);
if (!inserted) {
throw std::runtime_error(
"[export_function] Attempting to export a function twice with "
@@ -564,8 +687,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
// Flatten the inputs to the function for tracing
std::vector<std::string> kwarg_keys;
auto inputs = args;
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
for (auto& [k, v] : sorted_kwargs) {
kwarg_keys.push_back(k);
inputs.push_back(v);
@@ -592,10 +713,18 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
detail::compile_simplify(tape, parents_map, trace_outputs, /* passes */ 3);
// Update header
// Update the table entry
fentry.kwarg_keys = kwarg_keys;
fentry.inputs = trace_inputs;
count++;
// Overwrite the header
if (callback) {
export_with_callback(trace_inputs, trace_outputs, tape, kwarg_keys);
return;
}
// Update the header
auto pos = os.tell();
os.seek(0);
write_header(os, count, ftable->shapeless);
@@ -616,10 +745,6 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
serialize(os, trace_inputs);
serialize(os, arrays_to_ids(trace_outputs));
// Update the table entry
fentry.kwarg_keys = std::move(kwarg_keys);
fentry.inputs = std::move(trace_inputs);
std::unordered_set<std::uintptr_t> input_set(
trace_input_ids.begin(), trace_input_ids.end());
@@ -730,6 +855,58 @@ void export_function(
exporter(file, fun, shapeless)(args, kwargs);
}
// Make a callback exporter for a function that takes positional arguments
// only. The function is wrapped so that it accepts (args, kwargs) and ignores
// the keyword arguments.
FunctionExporter exporter(
    const ExportCallback& callback,
    const std::function<std::vector<array>(const Args&)>& fun,
    bool shapeless /* = false */) {
  auto positional_only = [fun](const Args& args, const Kwargs&) {
    return fun(args);
  };
  return FunctionExporter{callback, std::move(positional_only), shapeless};
}
// Make a callback exporter for a function that takes keyword arguments only.
// The function is wrapped so that it accepts (args, kwargs) and ignores the
// positional arguments, then handed to the (args, kwargs) overload.
FunctionExporter exporter(
    const ExportCallback& callback,
    const std::function<std::vector<array>(const Kwargs&)>& fun,
    bool shapeless /* = false */) {
  return exporter(
      callback,
      // Take the keyword map by reference: a by-value parameter would copy
      // the whole map every time the wrapped function is invoked.
      [fun](const Args&, const Kwargs& kwargs) { return fun(kwargs); },
      shapeless);
}
// Make a callback exporter for a function that takes both positional and
// keyword arguments. `callback`, `fun` and `shapeless` are passed straight to
// the FunctionExporter constructor.
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless /* = false */) {
return FunctionExporter{callback, fun, shapeless};
}
// Create a one-off callback exporter for `fun` and invoke it once on the
// positional arguments `args`.
void export_function(
    const ExportCallback& callback,
    const std::function<std::vector<array>(const Args&)>& fun,
    const Args& args,
    bool shapeless /* = false */) {
  auto one_shot = exporter(callback, fun, shapeless);
  one_shot(args);
}
// Create a one-off callback exporter for `fun` and invoke it once on the
// keyword arguments `kwargs`.
void export_function(
    const ExportCallback& callback,
    const std::function<std::vector<array>(const Kwargs&)>& fun,
    const Kwargs& kwargs,
    bool shapeless /* = false */) {
  auto one_shot = exporter(callback, fun, shapeless);
  one_shot(kwargs);
}
// Create a one-off callback exporter for `fun` and invoke it once on the
// positional arguments `args` and keyword arguments `kwargs`.
void export_function(
    const ExportCallback& callback,
    const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
    const Args& args,
    const Kwargs& kwargs,
    bool shapeless /* = false */) {
  auto one_shot = exporter(callback, fun, shapeless);
  one_shot(args, kwargs);
}
// Call the imported function with keyword arguments only; forwards to the
// (args, kwargs) overload with an empty positional argument list.
std::vector<array> ImportedFunction::operator()(const Kwargs& kwargs) const {
  const Args no_positional_args;
  return (*this)(no_positional_args, kwargs);
}
@@ -741,7 +918,9 @@ std::vector<array> ImportedFunction::operator()(const Args& args) const {
std::vector<array> ImportedFunction::operator()(
const Args& args,
const Kwargs& kwargs) const {
auto* fun = ftable->find(args, kwargs);
auto sorted_kwargs =
std::map<std::string, array>(kwargs.begin(), kwargs.end());
auto* fun = ftable->find(args, sorted_kwargs);
if (fun == nullptr) {
std::ostringstream msg;
msg << "[import_function::call] No imported function found which matches "
@@ -760,7 +939,7 @@ std::vector<array> ImportedFunction::operator()(
}
auto inputs = args;
for (auto& [_, v] : kwargs) {
for (auto& [_, v] : sorted_kwargs) {
inputs.push_back(v);
}
return detail::compile_replace(

View File

@@ -4,6 +4,7 @@
#include <set>
#include <unordered_map>
#include <variant>
#include "mlx/array.h"
namespace mlx::core {
@@ -11,6 +12,30 @@ namespace mlx::core {
using Args = std::vector<array>;
using Kwargs = std::unordered_map<std::string, array>;
// Possible types for a Primitive's state.
//
// When a function is exported with a callback, each value extracted from a
// primitive's state is stored as one of these alternatives. Enum-valued state
// is converted to `int` before it is stored.
using StateT = std::variant<
bool,
int,
size_t,
float,
double,
Dtype,
Shape,
Strides,
std::vector<int>,
std::vector<size_t>,
std::string>;
// A single message passed to an export callback: a map from a string key to
// one of the payload types below.
//
// The exporter sends messages whose "type" entry is one of "inputs",
// "keyword_inputs", "outputs", "constants" or "primitive". The remaining keys
// it uses are:
//   "inputs" / "outputs" - (name, shape, dtype) tuples,
//   "keywords"           - (keyword, input name) pairs,
//   "constants"          - (name, array) pairs,
//   "name"               - the primitive name (string),
//   "arguments"          - the primitive state as a list of StateT.
using ExportCallbackInput = std::unordered_map<
std::string,
std::variant<
std::vector<std::tuple<std::string, Shape, Dtype>>,
std::vector<std::pair<std::string, array>>,
std::vector<std::pair<std::string, std::string>>,
std::vector<StateT>,
std::string>>;
// Signature of the user-supplied function that receives each message.
using ExportCallback = std::function<void(const ExportCallbackInput&)>;
struct FunctionExporter;
/**
@@ -61,6 +86,47 @@ struct ImportedFunction;
*/
ImportedFunction import_function(const std::string& file);
/**
 * Make an exporter to export multiple traces of a given function with the same
 * callback.
 *
 * The three overloads accept a function taking positional arguments only,
 * keyword arguments only, or both.
 *
 * @param callback  Function that receives the export messages.
 * @param fun       The function to export.
 * @param shapeless Passed on to the exporter; defaults to false.
 */
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
bool shapeless = false);
FunctionExporter exporter(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
bool shapeless = false);
/**
 * Export a function with a callback.
 *
 * Each overload builds an exporter for `fun` with the given `callback` and
 * `shapeless` flag and invokes it once on the supplied arguments.
 *
 * @param callback  Function that receives the export messages.
 * @param fun       The function to export.
 * @param args      Positional example inputs (overloads taking Args).
 * @param kwargs    Keyword example inputs (overloads taking Kwargs).
 * @param shapeless Passed on to the exporter; defaults to false.
 */
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&)>& fun,
const Args& args,
bool shapeless = false);
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Kwargs&)>& fun,
const Kwargs& kwargs,
bool shapeless = false);
void export_function(
const ExportCallback& callback,
const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
const Args& args,
const Kwargs& kwargs,
bool shapeless = false);
} // namespace mlx::core
#include "mlx/export_impl.h"

View File

@@ -38,13 +38,40 @@ struct FunctionExporter {
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Args&)>&,
bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Kwargs&)>&,
bool shapeless);
friend FunctionExporter exporter(
const ExportCallback&,
const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
bool shapeless);
FunctionExporter(
const std::string& file,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless);
FunctionExporter(
const ExportCallback& callback,
std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
bool shapeless);
io::FileWriter os;
ExportCallback callback;
std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
void export_function(const Args& args, const Kwargs& kwargs);
void export_with_callback(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::vector<std::string>& kwarg_keys);
std::set<std::uintptr_t> constants;
int count{0};
bool closed{false};

View File

@@ -108,6 +108,7 @@ class ParallelFileReader : public Reader {
class FileWriter : public Writer {
public:
explicit FileWriter() {}
explicit FileWriter(std::string file_path)
: fd_(open(
file_path.c_str(),