mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Skip compile when transforming (#635)
* skip compile when transforming * simplify message
This commit is contained in:
@@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") {
|
||||
auto expected_out = repeat_input_to_compiled({x})[0];
|
||||
CHECK(allclose(out, expected_out).item<bool>());
|
||||
}
|
||||
|
||||
auto compile_unary_inner(const std::vector<array>& inputs) {
|
||||
auto x = inputs[0];
|
||||
return std::vector<array>{exp(exp(x))};
|
||||
}
|
||||
|
||||
auto compile_unary_outer(const std::vector<array>& inputs) {
|
||||
auto cfun = compile(compile_unary_inner);
|
||||
return cfun(cfun(inputs));
|
||||
}
|
||||
|
||||
TEST_CASE("test compile compiled function") {
|
||||
auto cfun = compile(compile_unary_outer);
|
||||
auto x = array({1.0f});
|
||||
auto out = cfun({x})[0];
|
||||
auto& p = out.primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||
}
|
||||
|
||||
auto grad_unary_compiled(const std::vector<array>& inputs) {
|
||||
auto gradfn = value_and_grad(compile(compile_unary_inner));
|
||||
auto [out, grad] = gradfn(inputs);
|
||||
return std::vector{out[0], grad[0]};
|
||||
}
|
||||
|
||||
TEST_CASE("test transform compiled function") {
|
||||
auto cfun = compile(grad_unary_compiled);
|
||||
auto x = array(1.0f);
|
||||
auto outs = cfun({x});
|
||||
auto& p = outs[0].primitive();
|
||||
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
|
||||
CHECK(!outs[0].inputs()[0].has_primitive());
|
||||
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||
}
|
||||
|
Reference in New Issue
Block a user