fix multi-output with compile

This commit is contained in:
Awni Hannun 2024-01-17 06:55:11 -08:00
parent ed4d867092
commit 390e50b163
5 changed files with 39 additions and 21 deletions

View File

@ -15,7 +15,7 @@ namespace detail {
bool& compiler_disabled() { bool& compiler_disabled() {
auto get_val = []() { auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) { if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
return true; return true;
} else { } else {
return false; return false;
@ -343,13 +343,30 @@ std::vector<array> compile_replace(
if (!a.has_primitive()) { if (!a.has_primitive()) {
trace_to_real.insert({a.id(), a}); trace_to_real.insert({a.id(), a});
} else { } else {
// Find real inputs
std::vector<array> real_inputs; std::vector<array> real_inputs;
for (auto& in : a.inputs()) { for (auto& in : a.inputs()) {
real_inputs.push_back(trace_to_real.at(in.id())); real_inputs.push_back(trace_to_real.at(in.id()));
} }
if (a.siblings().empty()) {
auto real_a = array( auto real_a = array(
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
trace_to_real.insert({a.id(), std::move(real_a)}); trace_to_real.insert({a.id(), std::move(real_a)});
} else {
// Ensure the order is correct for multi-output primitives
std::vector<std::vector<int>> shapes;
std::vector<Dtype> types;
auto trace_out = a.outputs();
for (auto& o : trace_out) {
shapes.push_back(o.shape());
types.push_back(o.dtype());
}
auto real_out =
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
for (int i = 0; i < trace_out.size(); ++i) {
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
}
}
} }
} }
@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
return detail::compile(fun, fun_id); return detail::compile(fun, fun_id);
} }
void disable_compiler() { void disable_compile() {
detail::compiler_disabled() = true; detail::compiler_disabled() = true;
} }
void enable_compiler() { void enable_compile() {
detail::compiler_disabled() = false; detail::compiler_disabled() = false;
} }

View File

@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
const std::function<std::vector<array>(const std::vector<array>&)>& fun); const std::function<std::vector<array>(const std::vector<array>&)>& fun);
/** Globally disable compilation. /** Globally disable compilation.
* Setting the environment variable ``MLX_DISABLE_COMPILER`` can also * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
* be used to disable compilation. * be used to disable compilation.
*/ */
void disable_compiler(); void disable_compile();
/** Globally enable compilation. /** Globally enable compilation.
* This will override the environment variable ``MLX_DISABLE_COMPILER``. * This will override the environment variable ``MLX_DISABLE_COMPILE``.
*/ */
void enable_compiler(); void enable_compile();
void eval(const std::vector<array>& outputs); void eval(const std::vector<array>& outputs);

View File

@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
as ``fun`` and returns the the same output(s). as ``fun`` and returns the the same output(s).
)pbdoc"); )pbdoc");
m.def( m.def(
"disable_compiler", "disable_compile",
&disable_compiler, &disable_compile,
R"pbdoc( R"pbdoc(
disable_compiler() -> None disable_compile() -> None
Globally disable compilation. Setting the environment variable Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILER`` can also be used to disable compilation. ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc"); )pbdoc");
m.def( m.def(
"enable_compiler", "enable_compile",
&enable_compiler, &enable_compile,
R"pbdoc( R"pbdoc(
enable_compiler() -> None enable_compiler() -> None
Globally enable compilation. This will override the environment Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILER`` if set. variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc"); )pbdoc");
// Register static Python object cleanup before the interpreter exits // Register static Python object cleanup before the interpreter exits

View File

@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
n_compiled = count_prims(cfun(x)) n_compiled = count_prims(cfun(x))
# Check disabled # Check disabled
mx.disable_compiler() mx.disable_compile()
n_uncompiled = count_prims(cfun(x)) n_uncompiled = count_prims(cfun(x))
self.assertTrue(n_compiled < n_uncompiled) self.assertTrue(n_compiled < n_uncompiled)
# Check renabled # Check renabled
mx.enable_compiler() mx.enable_compile()
n_enable_compiled = count_prims(cfun(x)) n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled) self.assertEqual(n_compiled, n_enable_compiled)

View File

@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <iostream>
#include "doctest/doctest.h" #include "doctest/doctest.h"
#include "mlx/mlx.h" #include "mlx/mlx.h"
@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
TEST_CASE("test enable and disable compile") { TEST_CASE("test enable and disable compile") {
CHECK_THROWS(compile(nullptr)); CHECK_THROWS(compile(nullptr));
disable_compiler(); disable_compile();
compile(nullptr); compile(nullptr);
enable_compiler(); enable_compile();
CHECK_THROWS(compile(nullptr)); CHECK_THROWS(compile(nullptr));
} }