mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix multi-output with compile
This commit is contained in:
		| @@ -15,7 +15,7 @@ namespace detail { | ||||
|  | ||||
| bool& compiler_disabled() { | ||||
|   auto get_val = []() { | ||||
|     if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) { | ||||
|     if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) { | ||||
|       return true; | ||||
|     } else { | ||||
|       return false; | ||||
| @@ -343,13 +343,30 @@ std::vector<array> compile_replace( | ||||
|     if (!a.has_primitive()) { | ||||
|       trace_to_real.insert({a.id(), a}); | ||||
|     } else { | ||||
|       // Find real inputs | ||||
|       std::vector<array> real_inputs; | ||||
|       for (auto& in : a.inputs()) { | ||||
|         real_inputs.push_back(trace_to_real.at(in.id())); | ||||
|       } | ||||
|       auto real_a = array( | ||||
|           a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); | ||||
|       trace_to_real.insert({a.id(), std::move(real_a)}); | ||||
|       if (a.siblings().empty()) { | ||||
|         auto real_a = array( | ||||
|             a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); | ||||
|         trace_to_real.insert({a.id(), std::move(real_a)}); | ||||
|       } else { | ||||
|         // Ensure the order is correct for multi-output primitives | ||||
|         std::vector<std::vector<int>> shapes; | ||||
|         std::vector<Dtype> types; | ||||
|         auto trace_out = a.outputs(); | ||||
|         for (auto& o : trace_out) { | ||||
|           shapes.push_back(o.shape()); | ||||
|           types.push_back(o.dtype()); | ||||
|         } | ||||
|         auto real_out = | ||||
|             array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs); | ||||
|         for (int i = 0; i < trace_out.size(); ++i) { | ||||
|           trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])}); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @@ -412,11 +429,11 @@ std::function<std::vector<array>(const std::vector<array>&)> compile( | ||||
|   return detail::compile(fun, fun_id); | ||||
| } | ||||
|  | ||||
| void disable_compiler() { | ||||
| void disable_compile() { | ||||
|   detail::compiler_disabled() = true; | ||||
| } | ||||
|  | ||||
| void enable_compiler() { | ||||
| void enable_compile() { | ||||
|   detail::compiler_disabled() = false; | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -11,15 +11,15 @@ std::function<std::vector<array>(const std::vector<array>&)> compile( | ||||
|     const std::function<std::vector<array>(const std::vector<array>&)>& fun); | ||||
|  | ||||
| /** Globally disable compilation. | ||||
|  * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also | ||||
|  * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also | ||||
|  * be used to disable compilation. | ||||
|  */ | ||||
| void disable_compiler(); | ||||
| void disable_compile(); | ||||
|  | ||||
| /** Globally enable compilation. | ||||
|  * This will override the environment variable ``MLX_DISABLE_COMPILER``. | ||||
|  * This will override the environment variable ``MLX_DISABLE_COMPILE``. | ||||
|  */ | ||||
| void enable_compiler(); | ||||
| void enable_compile(); | ||||
|  | ||||
| void eval(const std::vector<array>& outputs); | ||||
|  | ||||
|   | ||||
| @@ -816,22 +816,22 @@ void init_transforms(py::module_& m) { | ||||
|             as ``fun`` and returns the the same output(s). | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "disable_compiler", | ||||
|       &disable_compiler, | ||||
|       "disable_compile", | ||||
|       &disable_compile, | ||||
|       R"pbdoc( | ||||
|         disable_compiler() -> None | ||||
|         disable_compile() -> None | ||||
|  | ||||
|         Globally disable compilation. Setting the environment variable | ||||
|         ``MLX_DISABLE_COMPILER`` can also be used to disable compilation. | ||||
|         ``MLX_DISABLE_COMPILE`` can also be used to disable compilation. | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "enable_compiler", | ||||
|       &enable_compiler, | ||||
|       "enable_compile", | ||||
|       &enable_compile, | ||||
|       R"pbdoc( | ||||
|         enable_compiler() -> None | ||||
|  | ||||
|         Globally enable compilation. This will override the environment | ||||
|         variable ``MLX_DISABLE_COMPILER`` if set. | ||||
|         variable ``MLX_DISABLE_COMPILE`` if set. | ||||
|       )pbdoc"); | ||||
|  | ||||
|   // Register static Python object cleanup before the interpreter exits | ||||
|   | ||||
| @@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         n_compiled = count_prims(cfun(x)) | ||||
|  | ||||
|         # Check disabled | ||||
|         mx.disable_compiler() | ||||
|         mx.disable_compile() | ||||
|         n_uncompiled = count_prims(cfun(x)) | ||||
|         self.assertTrue(n_compiled < n_uncompiled) | ||||
|  | ||||
|         # Check renabled | ||||
|         mx.enable_compiler() | ||||
|         mx.enable_compile() | ||||
|         n_enable_compiled = count_prims(cfun(x)) | ||||
|         self.assertEqual(n_compiled, n_enable_compiled) | ||||
|  | ||||
|   | ||||
| @@ -1,5 +1,6 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #include <iostream> | ||||
| #include "doctest/doctest.h" | ||||
|  | ||||
| #include "mlx/mlx.h" | ||||
| @@ -99,9 +100,9 @@ TEST_CASE("test nested compile") { | ||||
|  | ||||
| TEST_CASE("test enable and disable compile") { | ||||
|   CHECK_THROWS(compile(nullptr)); | ||||
|   disable_compiler(); | ||||
|   disable_compile(); | ||||
|   compile(nullptr); | ||||
|   enable_compiler(); | ||||
|   enable_compile(); | ||||
|   CHECK_THROWS(compile(nullptr)); | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun