mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	fix multi-output with compile
This commit is contained in:
		@@ -15,7 +15,7 @@ namespace detail {
 | 
			
		||||
 | 
			
		||||
bool& compiler_disabled() {
 | 
			
		||||
  auto get_val = []() {
 | 
			
		||||
    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
 | 
			
		||||
    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
 | 
			
		||||
      return true;
 | 
			
		||||
    } else {
 | 
			
		||||
      return false;
 | 
			
		||||
@@ -343,13 +343,30 @@ std::vector<array> compile_replace(
 | 
			
		||||
    if (!a.has_primitive()) {
 | 
			
		||||
      trace_to_real.insert({a.id(), a});
 | 
			
		||||
    } else {
 | 
			
		||||
      // Find real inputs
 | 
			
		||||
      std::vector<array> real_inputs;
 | 
			
		||||
      for (auto& in : a.inputs()) {
 | 
			
		||||
        real_inputs.push_back(trace_to_real.at(in.id()));
 | 
			
		||||
      }
 | 
			
		||||
      if (a.siblings().empty()) {
 | 
			
		||||
        auto real_a = array(
 | 
			
		||||
            a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
 | 
			
		||||
        trace_to_real.insert({a.id(), std::move(real_a)});
 | 
			
		||||
      } else {
 | 
			
		||||
        // Ensure the order is correct for multi-output primitives
 | 
			
		||||
        std::vector<std::vector<int>> shapes;
 | 
			
		||||
        std::vector<Dtype> types;
 | 
			
		||||
        auto trace_out = a.outputs();
 | 
			
		||||
        for (auto& o : trace_out) {
 | 
			
		||||
          shapes.push_back(o.shape());
 | 
			
		||||
          types.push_back(o.dtype());
 | 
			
		||||
        }
 | 
			
		||||
        auto real_out =
 | 
			
		||||
            array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
 | 
			
		||||
        for (int i = 0; i < trace_out.size(); ++i) {
 | 
			
		||||
          trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
  return detail::compile(fun, fun_id);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void disable_compiler() {
 | 
			
		||||
void disable_compile() {
 | 
			
		||||
  detail::compiler_disabled() = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void enable_compiler() {
 | 
			
		||||
void enable_compile() {
 | 
			
		||||
  detail::compiler_disabled() = false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun);
 | 
			
		||||
 | 
			
		||||
/** Globally disable compilation.
 | 
			
		||||
 * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
 | 
			
		||||
 * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
 | 
			
		||||
 * be used to disable compilation.
 | 
			
		||||
 */
 | 
			
		||||
void disable_compiler();
 | 
			
		||||
void disable_compile();
 | 
			
		||||
 | 
			
		||||
/** Globally enable compilation.
 | 
			
		||||
 * This will override the environment variable ``MLX_DISABLE_COMPILER``.
 | 
			
		||||
 * This will override the environment variable ``MLX_DISABLE_COMPILE``.
 | 
			
		||||
 */
 | 
			
		||||
void enable_compiler();
 | 
			
		||||
void enable_compile();
 | 
			
		||||
 | 
			
		||||
void eval(const std::vector<array>& outputs);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -816,22 +816,22 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
            as ``fun`` and returns the the same output(s).
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "disable_compiler",
 | 
			
		||||
      &disable_compiler,
 | 
			
		||||
      "disable_compile",
 | 
			
		||||
      &disable_compile,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        disable_compiler() -> None
 | 
			
		||||
        disable_compile() -> None
 | 
			
		||||
 | 
			
		||||
        Globally disable compilation. Setting the environment variable
 | 
			
		||||
        ``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
 | 
			
		||||
        ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "enable_compiler",
 | 
			
		||||
      &enable_compiler,
 | 
			
		||||
      "enable_compile",
 | 
			
		||||
      &enable_compile,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        enable_compiler() -> None
 | 
			
		||||
 | 
			
		||||
        Globally enable compilation. This will override the environment
 | 
			
		||||
        variable ``MLX_DISABLE_COMPILER`` if set.
 | 
			
		||||
        variable ``MLX_DISABLE_COMPILE`` if set.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
 
 | 
			
		||||
@@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
			
		||||
        n_compiled = count_prims(cfun(x))
 | 
			
		||||
 | 
			
		||||
        # Check disabled
 | 
			
		||||
        mx.disable_compiler()
 | 
			
		||||
        mx.disable_compile()
 | 
			
		||||
        n_uncompiled = count_prims(cfun(x))
 | 
			
		||||
        self.assertTrue(n_compiled < n_uncompiled)
 | 
			
		||||
 | 
			
		||||
        # Check renabled
 | 
			
		||||
        mx.enable_compiler()
 | 
			
		||||
        mx.enable_compile()
 | 
			
		||||
        n_enable_compiled = count_prims(cfun(x))
 | 
			
		||||
        self.assertEqual(n_compiled, n_enable_compiled)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include "doctest/doctest.h"
 | 
			
		||||
 | 
			
		||||
#include "mlx/mlx.h"
 | 
			
		||||
@@ -99,9 +100,9 @@ TEST_CASE("test nested compile") {
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test enable and disable compile") {
 | 
			
		||||
  CHECK_THROWS(compile(nullptr));
 | 
			
		||||
  disable_compiler();
 | 
			
		||||
  disable_compile();
 | 
			
		||||
  compile(nullptr);
 | 
			
		||||
  enable_compiler();
 | 
			
		||||
  enable_compile();
 | 
			
		||||
  CHECK_THROWS(compile(nullptr));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user