allow compiling lambdas in C++ (#1650)

* allow compiling lambdas in C++

* fix test

* more tests

* auto detect capture-less lambda
This commit is contained in:
Awni Hannun
2024-12-06 13:13:21 -08:00
committed by GitHub
parent fd3377dd1f
commit 69a2991614
4 changed files with 115 additions and 12 deletions

View File

@@ -730,3 +730,53 @@ TEST_CASE("test compile change streams") {
out = cfun({array(1.0f), array(2.0f)})[0];
CHECK_EQ(out.primitive().stream(), s);
}
TEST_CASE("test compile lambda") {
auto fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(inputs[0])};
};
auto out = compile(fun)({array(-1)});
CHECK_EQ(out[0].item<int>(), 1);
decltype(compile(nullptr)) c_local_fun;
{
auto local_fun = [](const std::vector<array>& inputs) {
return std::vector<array>{abs(inputs[0])};
};
c_local_fun = compile(local_fun);
}
// This is ok even though local_fun is out of scope
out = c_local_fun({array(-1)});
CHECK_EQ(out[0].item<int>(), 1);
{
int x = 2;
auto local_fun = [x](const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + x};
};
c_local_fun = compile(local_fun);
}
// Also ok even though local_fun is out of scope.
out = c_local_fun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
int x = 2;
auto fun_with_capture = [&x](const std::vector<array>& inputs) {
return std::vector<array>{inputs[0] + x};
};
auto cfun = compile(fun_with_capture);
out = cfun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
// Doesn't recompile
x = 3;
out = cfun({array(0)});
CHECK_EQ(out[0].item<int>(), 2);
// Recompiles
auto cfun2 = compile(fun_with_capture);
out = cfun2({array(0)});
CHECK_EQ(out[0].item<int>(), 3);
}