Skip compile when transforming (#635)

* skip compile when transforming

* simplify message
This commit is contained in:
Awni Hannun 2024-02-05 21:28:37 -08:00 committed by GitHub
parent 316ff490b3
commit 146bd69470
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 33 deletions

View File

@ -86,50 +86,20 @@ std::vector<array> Compiled::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
}
std::vector<array> Compiled::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
}
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto outputs = mlx::core::vmap(fun, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
}
bool Compiled::is_equivalent(const Primitive& other) const {
@ -734,6 +704,13 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
})) {
return fun(inputs);
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);

View File

@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") {
auto expected_out = repeat_input_to_compiled({x})[0];
CHECK(allclose(out, expected_out).item<bool>());
}
auto compile_unary_inner(const std::vector<array>& inputs) {
auto x = inputs[0];
return std::vector<array>{exp(exp(x))};
}
auto compile_unary_outer(const std::vector<array>& inputs) {
auto cfun = compile(compile_unary_inner);
return cfun(cfun(inputs));
}
TEST_CASE("test compile compiled function") {
auto cfun = compile(compile_unary_outer);
auto x = array({1.0f});
auto out = cfun({x})[0];
auto& p = out.primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(out.inputs()[0].id(), x.id());
}
auto grad_unary_compiled(const std::vector<array>& inputs) {
auto gradfn = value_and_grad(compile(compile_unary_inner));
auto [out, grad] = gradfn(inputs);
return std::vector{out[0], grad[0]};
}
TEST_CASE("test transform compiled function") {
auto cfun = compile(grad_unary_compiled);
auto x = array(1.0f);
auto outs = cfun({x});
auto& p = outs[0].primitive();
CHECK_EQ(typeid(p), typeid(Compiled));
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
CHECK(!outs[0].inputs()[0].has_primitive());
CHECK(!outs[0].inputs()[1].has_primitive());
}