mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Skip compile when transforming (#635)
* skip compile when transforming * simplify message
This commit is contained in:
parent
316ff490b3
commit
146bd69470
@ -86,50 +86,20 @@ std::vector<array> Compiled::vjp(
|
|||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) {
|
const std::vector<array>& outputs) {
|
||||||
auto fun = [this](const std::vector<array>& inputs) {
|
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
|
||||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto [_, vjps] = mlx::core::vjp(fun, primals, cotangents);
|
|
||||||
std::vector<array> vjp_outs;
|
|
||||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
|
||||||
if (i < argnums.size() && i == argnums[j]) {
|
|
||||||
vjp_outs.push_back(vjps[i]);
|
|
||||||
j++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return vjp_outs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Compiled::jvp(
|
std::vector<array> Compiled::jvp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
auto fun = [this](const std::vector<array>& inputs) {
|
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
|
||||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto [_, jvps] = mlx::core::jvp(fun, primals, tangents);
|
|
||||||
std::vector<array> jvp_outs;
|
|
||||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
|
||||||
if (i < argnums.size() && i == argnums[j]) {
|
|
||||||
jvp_outs.push_back(jvps[i]);
|
|
||||||
j++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jvp_outs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto fun = [this](const std::vector<array>& inputs) {
|
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
|
||||||
return detail::compile_replace(tape_, inputs_, outputs_, inputs);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto outputs = mlx::core::vmap(fun, axes)(inputs);
|
|
||||||
auto out_axes = std::vector<int>(outputs.size(), 0);
|
|
||||||
return {outputs, out_axes};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Compiled::is_equivalent(const Primitive& other) const {
|
bool Compiled::is_equivalent(const Primitive& other) const {
|
||||||
@ -734,6 +704,13 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||||
|
// If the inputs are tracers, trace the original graph
|
||||||
|
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||||
|
return in.is_tracer();
|
||||||
|
})) {
|
||||||
|
return fun(inputs);
|
||||||
|
}
|
||||||
|
|
||||||
// Find a cache entry with the correct inputs
|
// Find a cache entry with the correct inputs
|
||||||
auto& entry = compiler_cache().find(fun_id, inputs);
|
auto& entry = compiler_cache().find(fun_id, inputs);
|
||||||
|
|
||||||
|
@ -587,3 +587,39 @@ TEST_CASE("test compile repeat input") {
|
|||||||
auto expected_out = repeat_input_to_compiled({x})[0];
|
auto expected_out = repeat_input_to_compiled({x})[0];
|
||||||
CHECK(allclose(out, expected_out).item<bool>());
|
CHECK(allclose(out, expected_out).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto compile_unary_inner(const std::vector<array>& inputs) {
|
||||||
|
auto x = inputs[0];
|
||||||
|
return std::vector<array>{exp(exp(x))};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto compile_unary_outer(const std::vector<array>& inputs) {
|
||||||
|
auto cfun = compile(compile_unary_inner);
|
||||||
|
return cfun(cfun(inputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test compile compiled function") {
|
||||||
|
auto cfun = compile(compile_unary_outer);
|
||||||
|
auto x = array({1.0f});
|
||||||
|
auto out = cfun({x})[0];
|
||||||
|
auto& p = out.primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(out.inputs()[0].id(), x.id());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto grad_unary_compiled(const std::vector<array>& inputs) {
|
||||||
|
auto gradfn = value_and_grad(compile(compile_unary_inner));
|
||||||
|
auto [out, grad] = gradfn(inputs);
|
||||||
|
return std::vector{out[0], grad[0]};
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test transform compiled function") {
|
||||||
|
auto cfun = compile(grad_unary_compiled);
|
||||||
|
auto x = array(1.0f);
|
||||||
|
auto outs = cfun({x});
|
||||||
|
auto& p = outs[0].primitive();
|
||||||
|
CHECK_EQ(typeid(p), typeid(Compiled));
|
||||||
|
CHECK_EQ(outs[0].siblings()[0].id(), outs[1].id());
|
||||||
|
CHECK(!outs[0].inputs()[0].has_primitive());
|
||||||
|
CHECK(!outs[0].inputs()[1].has_primitive());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user