Skip compile when transforming (#635)

* skip compile when transforming

* simplify message
This commit is contained in:
Awni Hannun
2024-02-05 21:28:37 -08:00
committed by GitHub
parent 316ff490b3
commit 146bd69470
2 changed files with 46 additions and 33 deletions

View File

@@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") {
auto expected_out = repeat_input_to_compiled({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}
auto compile_unary_inner(const std::vector<array>& inputs) {
auto x = inputs[0];
return std::vector<array>{exp(exp(x))};
}
auto compile_unary_outer(const std::vector<array>& inputs) {
auto cfun = compile(compile_unary_inner);
return cfun(cfun(inputs));
}
TEST_CASE("test compile compiled function") {
auto cfun = compile(compile_unary_outer);
auto x = array({1.0f});
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs()[0].id(), x.id());
}
auto grad_unary_compiled(const std::vector<array>& inputs) {
auto gradfn = value_and_grad(compile(compile_unary_inner));
auto [out, grad] = gradfn(inputs);
return std::vector{out[0], grad[0]};
}
TEST_CASE("test transform compiled function") {
auto cfun = compile(grad_unary_compiled);
auto x = array(1.0f);
auto outs = cfun({x});
auto& p = outs[0].primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
CHECK(!outs[0].inputs()[0].has_primitive());
CHECK(!outs[0].inputs()[1].has_primitive());
}