diff --git a/mlx/compile.cpp b/mlx/compile.cpp index e6748f505..074453ad2 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -15,7 +15,7 @@ namespace detail { bool& compiler_disabled() { auto get_val = []() { - if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) { + if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) { return true; } else { return false; @@ -343,13 +343,30 @@ std::vector compile_replace( if (!a.has_primitive()) { trace_to_real.insert({a.id(), a}); } else { + // Find real inputs std::vector real_inputs; for (auto& in : a.inputs()) { real_inputs.push_back(trace_to_real.at(in.id())); } - auto real_a = array( - a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); - trace_to_real.insert({a.id(), std::move(real_a)}); + if (a.siblings().empty()) { + auto real_a = array( + a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + trace_to_real.insert({a.id(), std::move(real_a)}); + } else { + // Ensure the order is correct for multi-output primitives + std::vector> shapes; + std::vector types; + auto trace_out = a.outputs(); + for (auto& o : trace_out) { + shapes.push_back(o.shape()); + types.push_back(o.dtype()); + } + auto real_out = + array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); + for (int i = 0; i < trace_out.size(); ++i) { + trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])}); + } + } } } @@ -412,11 +429,11 @@ std::function(const std::vector&)> compile( return detail::compile(fun, fun_id); } -void disable_compiler() { +void disable_compile() { detail::compiler_disabled() = true; } -void enable_compiler() { +void enable_compile() { detail::compiler_disabled() = false; } diff --git a/mlx/transforms.h b/mlx/transforms.h index b380229ec..7a13b8f7e 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -11,15 +11,15 @@ std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun); /** Globally disable compilation. - * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also + * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also * be used to disable compilation. */ -void disable_compiler(); +void disable_compile(); /** Globally enable compilation. - * This will override the environment variable ``MLX_DISABLE_COMPILER``. + * This will override the environment variable ``MLX_DISABLE_COMPILE``. */ -void enable_compiler(); +void enable_compile(); void eval(const std::vector& outputs); diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 6897f72e0..ca1931ec3 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -816,22 +816,22 @@ void init_transforms(py::module_& m) { as ``fun`` and returns the the same output(s). )pbdoc"); m.def( - "disable_compiler", - &disable_compiler, + "disable_compile", + &disable_compile, R"pbdoc( - disable_compiler() -> None + disable_compile() -> None Globally disable compilation. Setting the environment variable - ``MLX_DISABLE_COMPILER`` can also be used to disable compilation. + ``MLX_DISABLE_COMPILE`` can also be used to disable compilation. )pbdoc"); m.def( - "enable_compiler", - &enable_compiler, + "enable_compile", + &enable_compile, R"pbdoc( enable_compiler() -> None Globally enable compilation. This will override the environment - variable ``MLX_DISABLE_COMPILER`` if set. + variable ``MLX_DISABLE_COMPILE`` if set. )pbdoc"); // Register static Python object cleanup before the interpreter exits diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index f139d0a8e..83c63185c 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase): n_compiled = count_prims(cfun(x)) # Check disabled - mx.disable_compiler() + mx.disable_compile() n_uncompiled = count_prims(cfun(x)) self.assertTrue(n_compiled < n_uncompiled) # Check renabled - mx.enable_compiler() + mx.enable_compile() n_enable_compiled = count_prims(cfun(x)) self.assertEqual(n_compiled, n_enable_compiled) diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index d21f2989d..5237b7196 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. +#include #include "doctest/doctest.h" #include "mlx/mlx.h" @@ -99,9 +100,9 @@ TEST_CASE("test nested compile") { TEST_CASE("test enable and disable compile") { CHECK_THROWS(compile(nullptr)); - disable_compiler(); + disable_compile(); compile(nullptr); - enable_compiler(); + enable_compile(); CHECK_THROWS(compile(nullptr)); }