mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 03:36:40 +08:00
allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++ * fix test * more tests * auto detect capture-less lambda
This commit is contained in:
parent
fd3377dd1f
commit
69a2991614
@ -208,8 +208,7 @@ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
|
||||
using FunType = T (*)(U...);
|
||||
const FunType* fun_ptr = fun.template target<FunType>();
|
||||
if (fun_ptr == nullptr) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a non-addressable function.");
|
||||
return 0;
|
||||
}
|
||||
return reinterpret_cast<std::uintptr_t>(*fun_ptr);
|
||||
}
|
||||
@ -817,17 +816,28 @@ void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
}
|
||||
}
|
||||
|
||||
bool skip_compile() {
|
||||
return compile_mode() == CompileMode::disabled ||
|
||||
!(compile_available_for_device(default_device()));
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless /* = false */,
|
||||
std::vector<uint64_t> constants /* = {} */) {
|
||||
if (compile_mode() == CompileMode::disabled ||
|
||||
!(compile_available_for_device(default_device()))) {
|
||||
if (skip_compile()) {
|
||||
return fun;
|
||||
}
|
||||
return [fun, fun_id, shapeless, constants = std::move(constants)](
|
||||
const std::vector<array>& inputs) {
|
||||
if (!fun) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a function without a target.");
|
||||
}
|
||||
|
||||
return [fun = std::move(fun),
|
||||
fun_id,
|
||||
shapeless,
|
||||
constants = std::move(constants)](const std::vector<array>& inputs) {
|
||||
// If the inputs are tracers, trace the original graph
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||
return in.is_tracer();
|
||||
@ -889,13 +899,41 @@ void compile_clear_cache() {
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
bool shapeless /* false */) {
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
if (detail::skip_compile()) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::get_function_address(fun);
|
||||
return detail::compile(fun, fun_id, shapeless);
|
||||
if (fun_id) {
|
||||
// If the function has an addressable target then no need to manage it's
|
||||
// lifetime
|
||||
return detail::compile(std::move(fun), fun_id, shapeless);
|
||||
} else {
|
||||
auto pfun = std::shared_ptr<
|
||||
std::function<std::vector<array>(const std::vector<array>&)>>(
|
||||
new std::function<std::vector<array>(const std::vector<array>&)>{fun},
|
||||
[](auto p) {
|
||||
detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
|
||||
delete p;
|
||||
});
|
||||
fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
|
||||
return detail::compile(
|
||||
[pfun = std::move(pfun)](const auto& inputs) {
|
||||
return (*pfun)(inputs);
|
||||
},
|
||||
fun_id,
|
||||
shapeless);
|
||||
}
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::vector<array>(fun)(const std::vector<array>&),
|
||||
bool shapeless /* = false */) {
|
||||
if (detail::skip_compile()) {
|
||||
return fun;
|
||||
}
|
||||
return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
|
||||
}
|
||||
|
||||
void disable_compile() {
|
||||
|
@ -10,9 +10,24 @@ enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
||||
|
||||
/** Compile takes a function and returns a compiled function. */
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
bool shapeless = false);
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
std::vector<array>(fun)(const std::vector<array>&),
|
||||
bool shapeless = false);
|
||||
|
||||
// Convert capture-less lambdas to function pointers.
|
||||
template <
|
||||
typename F,
|
||||
typename = std::enable_if_t<
|
||||
std::is_convertible_v<F, decltype(+std::declval<F>())>>>
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
F&& f,
|
||||
bool shapeless = false) {
|
||||
return compile(+f, shapeless);
|
||||
}
|
||||
|
||||
/** Globally disable compilation.
|
||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||
* be used to disable compilation.
|
||||
|
@ -9,7 +9,7 @@ namespace mlx::core::detail {
|
||||
// This is not part of the general C++ API as calling with a bad id is a bad
|
||||
// idea.
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
std::function<std::vector<array>(const std::vector<array>&)> fun,
|
||||
std::uintptr_t fun_id,
|
||||
bool shapeless = false,
|
||||
std::vector<uint64_t> constants = {});
|
||||
|
@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
|
||||
out = cfun({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.primitive().stream(), s);
|
||||
}
|
||||
|
||||
TEST_CASE("test compile lambda") {
|
||||
auto fun = [](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(inputs[0])};
|
||||
};
|
||||
|
||||
auto out = compile(fun)({array(-1)});
|
||||
CHECK_EQ(out[0].item<int>(), 1);
|
||||
|
||||
decltype(compile(nullptr)) c_local_fun;
|
||||
{
|
||||
auto local_fun = [](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(inputs[0])};
|
||||
};
|
||||
c_local_fun = compile(local_fun);
|
||||
}
|
||||
|
||||
// This is ok even though local_fun is out of scope
|
||||
out = c_local_fun({array(-1)});
|
||||
CHECK_EQ(out[0].item<int>(), 1);
|
||||
|
||||
{
|
||||
int x = 2;
|
||||
auto local_fun = [x](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + x};
|
||||
};
|
||||
c_local_fun = compile(local_fun);
|
||||
}
|
||||
// Also ok even though local_fun is out of scope.
|
||||
out = c_local_fun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
int x = 2;
|
||||
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + x};
|
||||
};
|
||||
auto cfun = compile(fun_with_capture);
|
||||
out = cfun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
// Doesn't recompile
|
||||
x = 3;
|
||||
out = cfun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
// Recompiles
|
||||
auto cfun2 = compile(fun_with_capture);
|
||||
out = cfun2({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 3);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user