Skip compile when transforming (#635)

* skip compile when transforming

* simplify message
This commit is contained in:
Awni Hannun
2024-02-05 21:28:37 -08:00
committed by GitHub
parent 316ff490b3
commit 146bd69470
2 changed files with 46 additions and 33 deletions

View File

@@ -86,50 +86,20 @@ std::vector<array> Compiled::vjp(
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
}
return vjp_outs;
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
}
std::vector<array> Compiled::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
}
}
return jvp_outs;
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
}
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto fun = [this](const std::vector<array>& inputs) {
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
};
auto outputs = mlx::core::vmap(fun, axes)(inputs);
auto out_axes = std::vector<int>(outputs.size(), 0);
return {outputs, out_axes};
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
}
bool Compiled::is_equivalent(const Primitive& other) const {
@@ -734,6 +704,13 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
return fun;
}
return [fun, fun_id](const std::vector<array>& inputs) {
// If the inputs are tracers, trace the original graph
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
return in.is_tracer();
})) {
return fun(inputs);
}
// Find a cache entry with the correct inputs
auto& entry = compiler_cache().find(fun_id, inputs);