mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix multi-output with compile
This commit is contained in:
		@@ -15,7 +15,7 @@ namespace detail {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
bool& compiler_disabled() {
 | 
					bool& compiler_disabled() {
 | 
				
			||||||
  auto get_val = []() {
 | 
					  auto get_val = []() {
 | 
				
			||||||
    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
 | 
					    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
 | 
				
			||||||
      return true;
 | 
					      return true;
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      return false;
 | 
					      return false;
 | 
				
			||||||
@@ -343,13 +343,30 @@ std::vector<array> compile_replace(
 | 
				
			|||||||
    if (!a.has_primitive()) {
 | 
					    if (!a.has_primitive()) {
 | 
				
			||||||
      trace_to_real.insert({a.id(), a});
 | 
					      trace_to_real.insert({a.id(), a});
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
 | 
					      // Find real inputs
 | 
				
			||||||
      std::vector<array> real_inputs;
 | 
					      std::vector<array> real_inputs;
 | 
				
			||||||
      for (auto& in : a.inputs()) {
 | 
					      for (auto& in : a.inputs()) {
 | 
				
			||||||
        real_inputs.push_back(trace_to_real.at(in.id()));
 | 
					        real_inputs.push_back(trace_to_real.at(in.id()));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      auto real_a = array(
 | 
					      if (a.siblings().empty()) {
 | 
				
			||||||
          a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
 | 
					        auto real_a = array(
 | 
				
			||||||
      trace_to_real.insert({a.id(), std::move(real_a)});
 | 
					            a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
 | 
				
			||||||
 | 
					        trace_to_real.insert({a.id(), std::move(real_a)});
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        // Ensure the order is correct for multi-output primitives
 | 
				
			||||||
 | 
					        std::vector<std::vector<int>> shapes;
 | 
				
			||||||
 | 
					        std::vector<Dtype> types;
 | 
				
			||||||
 | 
					        auto trace_out = a.outputs();
 | 
				
			||||||
 | 
					        for (auto& o : trace_out) {
 | 
				
			||||||
 | 
					          shapes.push_back(o.shape());
 | 
				
			||||||
 | 
					          types.push_back(o.dtype());
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        auto real_out =
 | 
				
			||||||
 | 
					            array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
 | 
				
			||||||
 | 
					        for (int i = 0; i < trace_out.size(); ++i) {
 | 
				
			||||||
 | 
					          trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
				
			|||||||
  return detail::compile(fun, fun_id);
 | 
					  return detail::compile(fun, fun_id);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void disable_compiler() {
 | 
					void disable_compile() {
 | 
				
			||||||
  detail::compiler_disabled() = true;
 | 
					  detail::compiler_disabled() = true;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void enable_compiler() {
 | 
					void enable_compile() {
 | 
				
			||||||
  detail::compiler_disabled() = false;
 | 
					  detail::compiler_disabled() = false;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
				
			|||||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun);
 | 
					    const std::function<std::vector<array>(const std::vector<array>&)>& fun);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Globally disable compilation.
 | 
					/** Globally disable compilation.
 | 
				
			||||||
 * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
 | 
					 * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
 | 
				
			||||||
 * be used to disable compilation.
 | 
					 * be used to disable compilation.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
void disable_compiler();
 | 
					void disable_compile();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Globally enable compilation.
 | 
					/** Globally enable compilation.
 | 
				
			||||||
 * This will override the environment variable ``MLX_DISABLE_COMPILER``.
 | 
					 * This will override the environment variable ``MLX_DISABLE_COMPILE``.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
void enable_compiler();
 | 
					void enable_compile();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void eval(const std::vector<array>& outputs);
 | 
					void eval(const std::vector<array>& outputs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
 | 
				
			|||||||
            as ``fun`` and returns the the same output(s).
 | 
					            as ``fun`` and returns the the same output(s).
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
      "disable_compiler",
 | 
					      "disable_compile",
 | 
				
			||||||
      &disable_compiler,
 | 
					      &disable_compile,
 | 
				
			||||||
      R"pbdoc(
 | 
					      R"pbdoc(
 | 
				
			||||||
        disable_compiler() -> None
 | 
					        disable_compile() -> None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Globally disable compilation. Setting the environment variable
 | 
					        Globally disable compilation. Setting the environment variable
 | 
				
			||||||
        ``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
 | 
					        ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
  m.def(
 | 
					  m.def(
 | 
				
			||||||
      "enable_compiler",
 | 
					      "enable_compile",
 | 
				
			||||||
      &enable_compiler,
 | 
					      &enable_compile,
 | 
				
			||||||
      R"pbdoc(
 | 
					      R"pbdoc(
 | 
				
			||||||
        enable_compiler() -> None
 | 
					        enable_compiler() -> None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Globally enable compilation. This will override the environment
 | 
					        Globally enable compilation. This will override the environment
 | 
				
			||||||
        variable ``MLX_DISABLE_COMPILER`` if set.
 | 
					        variable ``MLX_DISABLE_COMPILE`` if set.
 | 
				
			||||||
      )pbdoc");
 | 
					      )pbdoc");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Register static Python object cleanup before the interpreter exits
 | 
					  // Register static Python object cleanup before the interpreter exits
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
				
			|||||||
        n_compiled = count_prims(cfun(x))
 | 
					        n_compiled = count_prims(cfun(x))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Check disabled
 | 
					        # Check disabled
 | 
				
			||||||
        mx.disable_compiler()
 | 
					        mx.disable_compile()
 | 
				
			||||||
        n_uncompiled = count_prims(cfun(x))
 | 
					        n_uncompiled = count_prims(cfun(x))
 | 
				
			||||||
        self.assertTrue(n_compiled < n_uncompiled)
 | 
					        self.assertTrue(n_compiled < n_uncompiled)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Check renabled
 | 
					        # Check renabled
 | 
				
			||||||
        mx.enable_compiler()
 | 
					        mx.enable_compile()
 | 
				
			||||||
        n_enable_compiled = count_prims(cfun(x))
 | 
					        n_enable_compiled = count_prims(cfun(x))
 | 
				
			||||||
        self.assertEqual(n_compiled, n_enable_compiled)
 | 
					        self.assertEqual(n_compiled, n_enable_compiled)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,6 @@
 | 
				
			|||||||
// Copyright © 2023 Apple Inc.
 | 
					// Copyright © 2023 Apple Inc.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <iostream>
 | 
				
			||||||
#include "doctest/doctest.h"
 | 
					#include "doctest/doctest.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlx/mlx.h"
 | 
					#include "mlx/mlx.h"
 | 
				
			||||||
@@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
TEST_CASE("test enable and disable compile") {
 | 
					TEST_CASE("test enable and disable compile") {
 | 
				
			||||||
  CHECK_THROWS(compile(nullptr));
 | 
					  CHECK_THROWS(compile(nullptr));
 | 
				
			||||||
  disable_compiler();
 | 
					  disable_compile();
 | 
				
			||||||
  compile(nullptr);
 | 
					  compile(nullptr);
 | 
				
			||||||
  enable_compiler();
 | 
					  enable_compile();
 | 
				
			||||||
  CHECK_THROWS(compile(nullptr));
 | 
					  CHECK_THROWS(compile(nullptr));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user