diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 743af64ed..fa9e0a987 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -86,50 +86,20 @@ std::vector Compiled::vjp( const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) { - auto fun = [this](const std::vector& inputs) { - return detail::compile_replace(tape_, inputs_, outputs_, inputs); - }; - - auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents); - std::vector vjp_outs; - for (int i = 0, j = 0; i < vjps.size(); ++i) { - if (i < argnums.size() && i == argnums[j]) { - vjp_outs.push_back(vjps[i]); - j++; - } - } - return vjp_outs; + throw std::runtime_error("[Compiled] Cannot vjp primitive."); } std::vector Compiled::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - auto fun = [this](const std::vector& inputs) { - return detail::compile_replace(tape_, inputs_, outputs_, inputs); - }; - - auto [_, jvps] = mlx::core::jvp(fun, primals, tangents); - std::vector jvp_outs; - for (int i = 0, j = 0; i < jvps.size(); ++i) { - if (i < argnums.size() && i == argnums[j]) { - jvp_outs.push_back(jvps[i]); - j++; - } - } - return jvp_outs; + throw std::runtime_error("[Compiled] Cannot jvp primitive."); } std::pair, std::vector> Compiled::vmap( const std::vector& inputs, const std::vector& axes) { - auto fun = [this](const std::vector& inputs) { - return detail::compile_replace(tape_, inputs_, outputs_, inputs); - }; - - auto outputs = mlx::core::vmap(fun, axes)(inputs); - auto out_axes = std::vector(outputs.size(), 0); - return {outputs, out_axes}; + throw std::runtime_error("[Compiled] Cannot vmap primitive."); } bool Compiled::is_equivalent(const Primitive& other) const { @@ -734,6 +704,13 @@ std::function(const std::vector&)> compile( return fun; } return [fun, fun_id](const std::vector& inputs) { + // If the inputs are tracers, trace the original graph + if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) { + return in.is_tracer(); + })) { + return fun(inputs); + } + // Find a cache entry with the correct inputs auto& entry = compiler_cache().find(fun_id, inputs); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 0d96ae23e..935a881a8 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") { auto expected_out = repeat_input_to_compiled({x})[0]; CHECK(allclose(out, expected_out).item()); } + +auto compile_unary_inner(const std::vector& inputs) { + auto x = inputs[0]; + return std::vector{exp(exp(x))}; +} + +auto compile_unary_outer(const std::vector& inputs) { + auto cfun = compile(compile_unary_inner); + return cfun(cfun(inputs)); +} + +TEST_CASE("test compile compiled function") { + auto cfun = compile(compile_unary_outer); + auto x = array({1.0f}); + auto out = cfun({x})[0]; + auto& p = out.primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(out.inputs()[0].id(), x.id()); +} + +auto grad_unary_compiled(const std::vector& inputs) { + auto gradfn = value_and_grad(compile(compile_unary_inner)); + auto [out, grad] = gradfn(inputs); + return std::vector{out[0], grad[0]}; +} + +TEST_CASE("test transform compiled function") { + auto cfun = compile(grad_unary_compiled); + auto x = array(1.0f); + auto outs = cfun({x}); + auto& p = outs[0].primitive(); + CHECK_EQ(typeid(p), typeid(Compiled)); + CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id()); + CHECK(!outs[0].inputs()[0].has_primitive()); + CHECK(!outs[0].inputs()[1].has_primitive()); +}