mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	enable and disable compiler
This commit is contained in:
		@@ -1,4 +1,5 @@
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <map>
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <unordered_set>
 | 
			
		||||
@@ -11,6 +12,20 @@ namespace mlx::core {
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
bool& compiler_disabled() {
 | 
			
		||||
  auto get_val = []() {
 | 
			
		||||
    if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) {
 | 
			
		||||
      return true;
 | 
			
		||||
    } else {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  static bool compiler_disabled_ = get_val();
 | 
			
		||||
  return compiler_disabled_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
 | 
			
		||||
 | 
			
		||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
 | 
			
		||||
using ParentsMap =
 | 
			
		||||
    std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
 | 
			
		||||
@@ -19,6 +34,10 @@ template <typename T, typename... U>
 | 
			
		||||
size_t getAddress(std::function<T(U...)> f) {
 | 
			
		||||
  typedef T(fnType)(U...);
 | 
			
		||||
  fnType** fnPointer = f.template target<fnType*>();
 | 
			
		||||
  if (fnPointer == nullptr) {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
        "[compile] Cannot compile a non-addressable function.");
 | 
			
		||||
  }
 | 
			
		||||
  return (size_t)*fnPointer;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -343,6 +362,9 @@ std::vector<array> compile_replace(
 | 
			
		||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
 | 
			
		||||
    size_t fun_id) {
 | 
			
		||||
  if (compiler_disabled()) {
 | 
			
		||||
    return fun;
 | 
			
		||||
  }
 | 
			
		||||
  return [fun, fun_id](const std::vector<array>& inputs) {
 | 
			
		||||
    // Find a cache entry with the correct inputs
 | 
			
		||||
    auto& entry = compiler_cache().find(fun_id, inputs);
 | 
			
		||||
@@ -386,12 +408,19 @@ void compile_clear() {
 | 
			
		||||
 | 
			
		||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
 | 
			
		||||
  auto fun_id = detail::getAddress(fun);
 | 
			
		||||
  if (fun_id == 0) {
 | 
			
		||||
    throw std::invalid_argument(
 | 
			
		||||
        "[compile] Cannot compile a non-addressable function.");
 | 
			
		||||
  if (detail::compiler_disabled()) {
 | 
			
		||||
    return fun;
 | 
			
		||||
  }
 | 
			
		||||
  auto fun_id = detail::getAddress(fun);
 | 
			
		||||
  return detail::compile(fun, fun_id);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void disable_compiler() {
 | 
			
		||||
  detail::compiler_disabled() = true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void enable_compiler() {
 | 
			
		||||
  detail::compiler_disabled() = false;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
@@ -10,6 +10,17 @@ namespace mlx::core {
 | 
			
		||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
    const std::function<std::vector<array>(const std::vector<array>&)>& fun);
 | 
			
		||||
 | 
			
		||||
/** Globally disable compilation.
 | 
			
		||||
 * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also
 | 
			
		||||
 * be used to disable compilation.
 | 
			
		||||
 */
 | 
			
		||||
void disable_compiler();
 | 
			
		||||
 | 
			
		||||
/** Globally enable compilation.
 | 
			
		||||
 * This will override the environment variable ``MLX_DISABLE_COMPILER``.
 | 
			
		||||
 */
 | 
			
		||||
void enable_compiler();
 | 
			
		||||
 | 
			
		||||
/** Fuse equivalent arrays to avoid duplicate execution. */
 | 
			
		||||
void simplify(const std::vector<array>& outputs);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -854,6 +854,24 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
            function: A compiled function which has the same input arguments
 | 
			
		||||
            as ``fun`` and returns the the same output(s).
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "disable_compiler",
 | 
			
		||||
      &disable_compiler,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        disable_compiler() -> None
 | 
			
		||||
 | 
			
		||||
        Globally disable compilation. Setting the environment variable
 | 
			
		||||
        ``MLX_DISABLE_COMPILER`` can also be used to disable compilation.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
  m.def(
 | 
			
		||||
      "enable_compiler",
 | 
			
		||||
      &enable_compiler,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        enable_compiler() -> None
 | 
			
		||||
 | 
			
		||||
        Globally enable compilation. This will override the environment
 | 
			
		||||
        variable ``MLX_DISABLE_COMPILER`` if set.
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
  auto atexit = py::module_::import("atexit");
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,6 @@
 | 
			
		||||
# Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
import io
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
@@ -146,6 +147,32 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
			
		||||
        out = cfun(mx.array(3))
 | 
			
		||||
        self.assertEqual(out.item(), 4)
 | 
			
		||||
 | 
			
		||||
    def test_enable_disable(self):
 | 
			
		||||
        def fun(x):
 | 
			
		||||
            y = x + 1
 | 
			
		||||
            z = x + 1
 | 
			
		||||
            return y + z
 | 
			
		||||
 | 
			
		||||
        def count_prims(outputs):
 | 
			
		||||
            buf = io.StringIO()
 | 
			
		||||
            mx.export_to_dot(buf, outputs)
 | 
			
		||||
            buf.seek(0)
 | 
			
		||||
            return len([l for l in buf.read().split() if "label" in l])
 | 
			
		||||
 | 
			
		||||
        x = mx.array(1.0)
 | 
			
		||||
        cfun = mx.compile(fun)
 | 
			
		||||
        n_compiled = count_prims(cfun(x))
 | 
			
		||||
 | 
			
		||||
        # Check disabled
 | 
			
		||||
        mx.disable_compiler()
 | 
			
		||||
        n_uncompiled = count_prims(cfun(x))
 | 
			
		||||
        self.assertTrue(n_compiled < n_uncompiled)
 | 
			
		||||
 | 
			
		||||
        # Check renabled
 | 
			
		||||
        mx.enable_compiler()
 | 
			
		||||
        n_enable_compiled = count_prims(cfun(x))
 | 
			
		||||
        self.assertEqual(n_compiled, n_enable_compiled)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
@@ -96,3 +96,11 @@ TEST_CASE("test nested compile") {
 | 
			
		||||
  out = cfun({array(1), array(2)})[0];
 | 
			
		||||
  CHECK_EQ(out.item<int>(), 9);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_CASE("test enable and disable compile") {
 | 
			
		||||
  CHECK_THROWS(compile(nullptr));
 | 
			
		||||
  disable_compiler();
 | 
			
		||||
  compile(nullptr);
 | 
			
		||||
  enable_compiler();
 | 
			
		||||
  CHECK_THROWS(compile(nullptr));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user