allow compiling lambdas in C++ (#1650)

* allow compiling lambdas in C++

* fix test

* more tests

* auto detect capture-less lambda
This commit is contained in:
Awni Hannun
2024-12-06 13:13:21 -08:00
committed by GitHub
parent fd3377dd1f
commit 69a2991614
4 changed files with 115 additions and 12 deletions

View File

@@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
using FunType = T (*)(U...); using FunType = T (*)(U...);
const FunType* fun_ptr = fun.template target<FunType>(); const FunType* fun_ptr = fun.template target<FunType>();
if (fun_ptr == nullptr) { if (fun_ptr == nullptr) {
throw std::invalid_argument( return 0;
"[compile] Cannot compile a non-addressable function.");
} }
return reinterpret_cast<std::uintptr_t>(*fun_ptr); return reinterpret_cast<std::uintptr_t>(*fun_ptr);
} }
@@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
} }
} }
bool skip_compile() {
return compile_mode() == CompileMode::disabled ||
!(compile_available_for_device(default_device()));
}
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, std::function<std::vector<array>(const std::vector<array>&)> fun,
std::uintptr_t fun_id, std::uintptr_t fun_id,
bool shapeless /* = false */, bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) { std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled || if (skip_compile()) {
!(compile_available_for_device(default_device()))) {
return fun; return fun;
} }
return [fun, fun_id, shapeless, constants = std::move(constants)]( if (!fun) {
const std::vector<array>& inputs) { throw std::invalid_argument(
"[compile] Cannot compile a function without a target.");
}
return [fun = std::move(fun),
fun_id,
shapeless,
constants = std::move(constants)](const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph // If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer(); return in.is_tracer();
@@ -889,13 +899,41 @@ void compile_clear_cache() {
} // namespace detail } // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, std::function<std::vector<array>(const std::vector<array>&)> fun,
bool shapeless /* false */) { bool shapeless /* false */) {
if (detail::compile_mode() == CompileMode::disabled) { if (detail::skip_compile()) {
return fun; return fun;
} }
auto fun_id = detail::get_function_address(fun); auto fun_id = detail::get_function_address(fun);
return detail::compile(fun, fun_id, shapeless); if (fun_id) {
// If the function has an addressable target then no need to manage it's
// lifetime
return detail::compile(std::move(fun), fun_id, shapeless);
} else {
auto pfun = std::shared_ptr<
std::function<std::vector<array>(const std::vector<array>&)>>(
new std::function<std::vector<array>(const std::vector<array>&)>{fun},
[](auto p) {
detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
delete p;
});
fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
return detail::compile(
[pfun = std::move(pfun)](const auto& inputs) {
return (*pfun)(inputs);
},
fun_id,
shapeless);
}
}
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::vector<array>(fun)(const std::vector<array>&),
bool shapeless /* = false */) {
if (detail::skip_compile()) {
return fun;
}
return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
} }
void disable_compile() { void disable_compile() {

View File

@@ -10,9 +10,24 @@ enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
/** Compile takes a function and returns a compiled function. */ /** Compile takes a function and returns a compiled function. */
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, std::function<std::vector<array>(const std::vector<array>&)> fun,
bool shapeless = false); bool shapeless = false);
std::function<std::vector<array>(const std::vector<array>&)> compile(
std::vector<array>(fun)(const std::vector<array>&),
bool shapeless = false);
// Convert capture-less lambdas to function pointers.
template <
typename F,
typename = std::enable_if_t<
std::is_convertible_v<F, decltype(+std::declval<F>())>>>
std::function<std::vector<array>(const std::vector<array>&)> compile(
F&& f,
bool shapeless = false) {
return compile(+f, shapeless);
}
/** Globally disable compilation. /** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation. * be used to disable compilation.

View File

@@ -9,7 +9,7 @@ namespace mlx::core::detail {
// This is not part of the general C++ API as calling with a bad id is a bad // This is not part of the general C++ API as calling with a bad id is a bad
// idea. // idea.
std::function<std::vector<array>(const std::vector<array>&)> compile( std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun, std::function<std::vector<array>(const std::vector<array>&)> fun,
std::uintptr_t fun_id, std::uintptr_t fun_id,
bool shapeless = false, bool shapeless = false,
std::vector<uint64_t> constants = {}); std::vector<uint64_t> constants = {});

View File

@@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
out = cfun({array(1.0f), array(2.0f)})[0]; out = cfun({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.primitive().stream(), s); CHECK_EQ(out.primitive().stream(), s);
} }
TEST_CASE("test compile lambda") {
auto fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(inputs[0])};
};
auto out = compile(fun)({array(-1)});
CHECK_EQ(out[0].item<int>(), 1);
decltype(compile(nullptr)) c_local_fun;
{
auto local_fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(inputs[0])};
};
c_local_fun = compile(local_fun);
}
// This is ok even though local_fun is out of scope
out = c_local_fun({array(-1)});
CHECK_EQ(out[0].item<int>(), 1);
{
int x = 2;
auto local_fun = [x](const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + x};
};
c_local_fun = compile(local_fun);
}
// Also ok even though local_fun is out of scope.
out = c_local_fun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
int x = 2;
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + x};
};
auto cfun = compile(fun_with_capture);
out = cfun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
// Doesn't recompile
x = 3;
out = cfun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
// Recompiles
auto cfun2 = compile(fun_with_capture);
out = cfun2({array(0)});
CHECK_EQ(out[0].item<int>(), 3);
}