mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++ * fix test * more tests * auto detect capture-less lambda
This commit is contained in:
parent
fd3377dd1f
commit
69a2991614
@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
|||||||
using FunType = T (*)(U...);
|
using FunType = T (*)(U...);
|
||||||
const FunType* fun_ptr = fun.template target<FunType>();
|
const FunType* fun_ptr = fun.template target<FunType>();
|
||||||
if (fun_ptr == nullptr) {
|
if (fun_ptr == nullptr) {
|
||||||
throw std::invalid_argument(
|
return 0;
|
||||||
"[compile] Cannot compile a non-addressable function.");
|
|
||||||
}
|
}
|
||||||
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||||
}
|
}
|
||||||
@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool skip_compile() {
|
||||||
|
return compile_mode() == CompileMode::disabled ||
|
||||||
|
!(compile_available_for_device(default_device()));
|
||||||
|
}
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||||
std::uintptr_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
bool shapeless /* = false */,
|
bool shapeless /* = false */,
|
||||||
std::vector<uint64_t> constants /* = {} */) {
|
std::vector<uint64_t> constants /* = {} */) {
|
||||||
if (compile_mode() == CompileMode::disabled ||
|
if (skip_compile()) {
|
||||||
!(compile_available_for_device(default_device()))) {
|
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
return [fun, fun_id, shapeless, constants = std::move(constants)](
|
if (!fun) {
|
||||||
const std::vector<array>& inputs) {
|
throw std::invalid_argument(
|
||||||
|
"[compile] Cannot compile a function without a target.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return [fun = std::move(fun),
|
||||||
|
fun_id,
|
||||||
|
shapeless,
|
||||||
|
constants = std::move(constants)](const std::vector<array>& inputs) {
|
||||||
// If the inputs are tracers, trace the original graph
|
// If the inputs are tracers, trace the original graph
|
||||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||||
return in.is_tracer();
|
return in.is_tracer();
|
||||||
@ -889,13 +899,41 @@ void compile_clear_cache() {
|
|||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||||
bool shapeless /* false */) {
|
bool shapeless /* false */) {
|
||||||
if (detail::compile_mode() == CompileMode::disabled) {
|
if (detail::skip_compile()) {
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
auto fun_id = detail::get_function_address(fun);
|
auto fun_id = detail::get_function_address(fun);
|
||||||
return detail::compile(fun, fun_id, shapeless);
|
if (fun_id) {
|
||||||
|
// If the function has an addressable target then no need to manage it's
|
||||||
|
// lifetime
|
||||||
|
return detail::compile(std::move(fun), fun_id, shapeless);
|
||||||
|
} else {
|
||||||
|
auto pfun = std::shared_ptr<
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)>>(
|
||||||
|
new std::function<std::vector<array>(const std::vector<array>&)>{fun},
|
||||||
|
[](auto p) {
|
||||||
|
detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
|
||||||
|
delete p;
|
||||||
|
});
|
||||||
|
fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
|
||||||
|
return detail::compile(
|
||||||
|
[pfun = std::move(pfun)](const auto& inputs) {
|
||||||
|
return (*pfun)(inputs);
|
||||||
|
},
|
||||||
|
fun_id,
|
||||||
|
shapeless);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
std::vector<array>(fun)(const std::vector<array>&),
|
||||||
|
bool shapeless /* = false */) {
|
||||||
|
if (detail::skip_compile()) {
|
||||||
|
return fun;
|
||||||
|
}
|
||||||
|
return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
|
||||||
}
|
}
|
||||||
|
|
||||||
void disable_compile() {
|
void disable_compile() {
|
||||||
|
@ -10,9 +10,24 @@ enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
|||||||
|
|
||||||
/** Compile takes a function and returns a compiled function. */
|
/** Compile takes a function and returns a compiled function. */
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||||
bool shapeless = false);
|
bool shapeless = false);
|
||||||
|
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
std::vector<array>(fun)(const std::vector<array>&),
|
||||||
|
bool shapeless = false);
|
||||||
|
|
||||||
|
// Convert capture-less lambdas to function pointers.
|
||||||
|
template <
|
||||||
|
typename F,
|
||||||
|
typename = std::enable_if_t<
|
||||||
|
std::is_convertible_v<F, decltype(+std::declval<F>())>>>
|
||||||
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
F&& f,
|
||||||
|
bool shapeless = false) {
|
||||||
|
return compile(+f, shapeless);
|
||||||
|
}
|
||||||
|
|
||||||
/** Globally disable compilation.
|
/** Globally disable compilation.
|
||||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||||
* be used to disable compilation.
|
* be used to disable compilation.
|
||||||
|
@ -9,7 +9,7 @@ namespace mlx::core::detail {
|
|||||||
// This is not part of the general C++ API as calling with a bad id is a bad
|
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||||
// idea.
|
// idea.
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||||
std::uintptr_t fun_id,
|
std::uintptr_t fun_id,
|
||||||
bool shapeless = false,
|
bool shapeless = false,
|
||||||
std::vector<uint64_t> constants = {});
|
std::vector<uint64_t> constants = {});
|
||||||
|
@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
|
|||||||
out = cfun({array(1.0f), array(2.0f)})[0];
|
out = cfun({array(1.0f), array(2.0f)})[0];
|
||||||
CHECK_EQ(out.primitive().stream(), s);
|
CHECK_EQ(out.primitive().stream(), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile lambda") {
|
||||||
|
auto fun = [](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(inputs[0])};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto out = compile(fun)({array(-1)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 1);
|
||||||
|
|
||||||
|
decltype(compile(nullptr)) c_local_fun;
|
||||||
|
{
|
||||||
|
auto local_fun = [](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{abs(inputs[0])};
|
||||||
|
};
|
||||||
|
c_local_fun = compile(local_fun);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is ok even though local_fun is out of scope
|
||||||
|
out = c_local_fun({array(-1)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 1);
|
||||||
|
|
||||||
|
{
|
||||||
|
int x = 2;
|
||||||
|
auto local_fun = [x](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{inputs[0] + x};
|
||||||
|
};
|
||||||
|
c_local_fun = compile(local_fun);
|
||||||
|
}
|
||||||
|
// Also ok even though local_fun is out of scope.
|
||||||
|
out = c_local_fun({array(0)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 2);
|
||||||
|
|
||||||
|
int x = 2;
|
||||||
|
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
|
||||||
|
return std::vector<array>{inputs[0] + x};
|
||||||
|
};
|
||||||
|
auto cfun = compile(fun_with_capture);
|
||||||
|
out = cfun({array(0)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 2);
|
||||||
|
|
||||||
|
// Doesn't recompile
|
||||||
|
x = 3;
|
||||||
|
out = cfun({array(0)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 2);
|
||||||
|
|
||||||
|
// Recompiles
|
||||||
|
auto cfun2 = compile(fun_with_capture);
|
||||||
|
out = cfun2({array(0)});
|
||||||
|
CHECK_EQ(out[0].item<int>(), 3);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user