diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 741385844f..44bab0298b 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function& fun) { using FunType = T (*)(U...); const FunType* fun_ptr = fun.template target(); if (fun_ptr == nullptr) { - throw std::invalid_argument( - "[compile] Cannot compile a non-addressable function."); + return 0; } return reinterpret_cast(*fun_ptr); } @@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector& tape) { } } +bool skip_compile() { + return compile_mode() == CompileMode::disabled || + !(compile_available_for_device(default_device())); +} + std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun, + std::function(const std::vector&)> fun, std::uintptr_t fun_id, bool shapeless /* = false */, std::vector constants /* = {} */) { - if (compile_mode() == CompileMode::disabled || - !(compile_available_for_device(default_device()))) { + if (skip_compile()) { return fun; } - return [fun, fun_id, shapeless, constants = std::move(constants)]( - const std::vector& inputs) { + if (!fun) { + throw std::invalid_argument( + "[compile] Cannot compile a function without a target."); + } + + return [fun = std::move(fun), + fun_id, + shapeless, + constants = std::move(constants)](const std::vector& inputs) { // If the inputs are tracers, trace the original graph if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { return in.is_tracer(); @@ -889,13 +899,41 @@ void compile_clear_cache() { } // namespace detail std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun, + std::function(const std::vector&)> fun, bool shapeless /* false */) { - if (detail::compile_mode() == CompileMode::disabled) { + if (detail::skip_compile()) { return fun; } auto fun_id = detail::get_function_address(fun); - return detail::compile(fun, fun_id, shapeless); + if (fun_id) { + // If the function has an addressable target then no need to manage it's + // lifetime + return detail::compile(std::move(fun), fun_id, shapeless); + } else { + auto pfun = std::shared_ptr< + std::function(const std::vector&)>>( + new std::function(const std::vector&)>{fun}, + [](auto p) { + detail::compile_erase(reinterpret_cast(p)); + delete p; + }); + fun_id = reinterpret_cast(pfun.get()); + return detail::compile( + [pfun = std::move(pfun)](const auto& inputs) { + return (*pfun)(inputs); + }, + fun_id, + shapeless); + } +} + +std::function(const std::vector&)> compile( + std::vector(fun)(const std::vector&), + bool shapeless /* = false */) { + if (detail::skip_compile()) { + return fun; + } + return detail::compile(fun, reinterpret_cast(fun), shapeless); } void disable_compile() { diff --git a/mlx/compile.h b/mlx/compile.h index 1134c20dc4..90b6a9c2aa 100644 --- a/mlx/compile.h +++ b/mlx/compile.h @@ -10,9 +10,24 @@ enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; /** Compile takes a function and returns a compiled function. */ std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun, + std::function(const std::vector&)> fun, bool shapeless = false); +std::function(const std::vector&)> compile( + std::vector(fun)(const std::vector&), + bool shapeless = false); + +// Convert capture-less lambdas to function pointers. +template < + typename F, + typename = std::enable_if_t< + std::is_convertible_v())>>> +std::function(const std::vector&)> compile( + F&& f, + bool shapeless = false) { + return compile(+f, shapeless); +} + /** Globally disable compilation. * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also * be used to disable compilation. diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 0f18e1dacd..913079bfb1 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -9,7 +9,7 @@ namespace mlx::core::detail { // This is not part of the general C++ API as calling with a bad id is a bad // idea. std::function(const std::vector&)> compile( - const std::function(const std::vector&)>& fun, + std::function(const std::vector&)> fun, std::uintptr_t fun_id, bool shapeless = false, std::vector constants = {}); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index a1559b7d3f..5bcba68f44 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") { out = cfun({array(1.0f), array(2.0f)})[0]; CHECK_EQ(out.primitive().stream(), s); } + +TEST_CASE("test compile lambda") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(inputs[0])}; + }; + + auto out = compile(fun)({array(-1)}); + CHECK_EQ(out[0].item(), 1); + + decltype(compile(nullptr)) c_local_fun; + { + auto local_fun = [](const std::vector& inputs) { + return std::vector{abs(inputs[0])}; + }; + c_local_fun = compile(local_fun); + } + + // This is ok even though local_fun is out of scope + out = c_local_fun({array(-1)}); + CHECK_EQ(out[0].item(), 1); + + { + int x = 2; + auto local_fun = [x](const std::vector& inputs) { + return std::vector{inputs[0] + x}; + }; + c_local_fun = compile(local_fun); + } + // Also ok even though local_fun is out of scope. + out = c_local_fun({array(0)}); + CHECK_EQ(out[0].item(), 2); + + int x = 2; + auto fun_with_capture = [&x](const std::vector& inputs) { + return std::vector{inputs[0] + x}; + }; + auto cfun = compile(fun_with_capture); + out = cfun({array(0)}); + CHECK_EQ(out[0].item(), 2); + + // Doesn't recompile + x = 3; + out = cfun({array(0)}); + CHECK_EQ(out[0].item(), 2); + + // Recompiles + auto cfun2 = compile(fun_with_capture); + out = cfun2({array(0)}); + CHECK_EQ(out[0].item(), 3); +}