Compile now can attach arbitrary data to an entry (#2634)

This commit is contained in:
Angelos Katharopoulos
2025-09-30 13:33:27 -07:00
committed by GitHub
parent dc371ae7a5
commit eb24267b56
5 changed files with 130 additions and 31 deletions

View File

@@ -8,6 +8,10 @@
namespace mlx::core::detail {
using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
using ArrayFnWithExtra =
std::function<ArraysAndExtra(const std::vector<array>&)>;
// This is not part of the general C++ API as calling with a bad id is a bad
// idea.
std::function<std::vector<array>(const std::vector<array>&)> compile(
@@ -16,6 +20,12 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
bool shapeless = false,
std::vector<uint64_t> constants = {});
ArrayFnWithExtra compile(
ArrayFnWithExtra fun,
std::uintptr_t fun_id,
bool shapeless,
std::vector<uint64_t> constants);
// Erase cached compile functions
void compile_erase(std::uintptr_t fun_id);
@@ -25,8 +35,9 @@ void compile_clear_cache();
bool compile_available_for_device(const Device& device);
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
compile_trace(
const ArrayFnWithExtra& fun,
const std::vector<array>& inputs,
bool shapeless);