mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Skip compile when transforming (#635)
* skip compile when transforming * simplify message
This commit is contained in:
@@ -86,50 +86,20 @@ std::vector<array> Compiled::vjp(
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
|
||||
std::vector<array> vjp_outs;
|
||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
vjp_outs.push_back(vjps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return vjp_outs;
|
||||
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
|
||||
}
|
||||
|
||||
std::vector<array> Compiled::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
|
||||
std::vector<array> jvp_outs;
|
||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
jvp_outs.push_back(jvps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return jvp_outs;
|
||||
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto fun = [this](const std::vector<array>& inputs) {
|
||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
||||
};
|
||||
|
||||
auto outputs = mlx::core::vmap(fun, axes)(inputs);
|
||||
auto out_axes = std::vector<int>(outputs.size(), 0);
|
||||
return {outputs, out_axes};
|
||||
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
|
||||
}
|
||||
|
||||
bool Compiled::is_equivalent(const Primitive& other) const {
|
||||
@@ -734,6 +704,13 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
return fun;
|
||||
}
|
||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||
// If the inputs are tracers, trace the original graph
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||
return in.is_tracer();
|
||||
})) {
|
||||
return fun(inputs);
|
||||
}
|
||||
|
||||
// Find a cache entry with the correct inputs
|
||||
auto& entry = compiler_cache().find(fun_id, inputs);
|
||||
|
||||
|
Reference in New Issue
Block a user