mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-09 06:54:26 +08:00
allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++ * fix test * more tests * auto detect capture-less lambda
This commit is contained in:
@@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
|
||||
out = cfun({array(1.0f), array(2.0f)})[0];
|
||||
CHECK_EQ(out.primitive().stream(), s);
|
||||
}
|
||||
|
||||
TEST_CASE("test compile lambda") {
|
||||
auto fun = [](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(inputs[0])};
|
||||
};
|
||||
|
||||
auto out = compile(fun)({array(-1)});
|
||||
CHECK_EQ(out[0].item<int>(), 1);
|
||||
|
||||
decltype(compile(nullptr)) c_local_fun;
|
||||
{
|
||||
auto local_fun = [](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{abs(inputs[0])};
|
||||
};
|
||||
c_local_fun = compile(local_fun);
|
||||
}
|
||||
|
||||
// This is ok even though local_fun is out of scope
|
||||
out = c_local_fun({array(-1)});
|
||||
CHECK_EQ(out[0].item<int>(), 1);
|
||||
|
||||
{
|
||||
int x = 2;
|
||||
auto local_fun = [x](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + x};
|
||||
};
|
||||
c_local_fun = compile(local_fun);
|
||||
}
|
||||
// Also ok even though local_fun is out of scope.
|
||||
out = c_local_fun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
int x = 2;
|
||||
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
|
||||
return std::vector<array>{inputs[0] + x};
|
||||
};
|
||||
auto cfun = compile(fun_with_capture);
|
||||
out = cfun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
// Doesn't recompile
|
||||
x = 3;
|
||||
out = cfun({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 2);
|
||||
|
||||
// Recompiles
|
||||
auto cfun2 = compile(fun_with_capture);
|
||||
out = cfun2({array(0)});
|
||||
CHECK_EQ(out[0].item<int>(), 3);
|
||||
}
|
||||
|
Reference in New Issue
Block a user