From 5c78c16f1ce83c2c48b62a1fb048532ec238b7b8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 16 Jan 2024 16:09:07 -0800 Subject: [PATCH] enable and disable compiler --- mlx/compile.cpp | 37 ++++++++++++++++++++++++++++++++---- mlx/transforms.h | 13 ++++++++++++- python/src/transforms.cpp | 18 ++++++++++++++++++ python/tests/test_compile.py | 27 ++++++++++++++++++++++++++ tests/compile_tests.cpp | 8 ++++++++ 5 files changed, 98 insertions(+), 5 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 0306fb229..55304c6f0 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,4 +1,5 @@ // Copyright © 2023-2024 Apple Inc. +#include #include #include #include @@ -11,6 +12,20 @@ namespace mlx::core { namespace detail { +bool& compiler_disabled() { + auto get_val = []() { + if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILER")) { + return true; + } else { + return false; + } + }; + static bool compiler_disabled_ = get_val(); + return compiler_disabled_; +} + +#define MAX_OPS_PER_BUFFER max_ops_per_buffer() + using CompileFn = std::function(const std::vector&)>; using ParentsMap = std::unordered_map>>; @@ -19,6 +34,10 @@ template size_t getAddress(std::function f) { typedef T(fnType)(U...); fnType** fnPointer = f.template target(); + if (fnPointer == nullptr) { + throw std::invalid_argument( + "[compile] Cannot compile a non-addressable function."); + } return (size_t)*fnPointer; } @@ -343,6 +362,9 @@ std::vector compile_replace( std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, size_t fun_id) { + if (compiler_disabled()) { + return fun; + } return [fun, fun_id](const std::vector& inputs) { // Find a cache entry with the correct inputs auto& entry = compiler_cache().find(fun_id, inputs); @@ -386,12 +408,19 @@ void compile_clear() { std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun) { - auto fun_id = detail::getAddress(fun); - if (fun_id == 0) { - throw std::invalid_argument( - "[compile] Cannot compile a non-addressable function."); + if (detail::compiler_disabled()) { + return fun; } + auto fun_id = detail::getAddress(fun); return detail::compile(fun, fun_id); } +void disable_compiler() { + detail::compiler_disabled() = true; +} + +void enable_compiler() { + detail::compiler_disabled() = false; +} + } // namespace mlx::core diff --git a/mlx/transforms.h b/mlx/transforms.h index 539f47efd..340c40bc5 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -10,6 +10,17 @@ namespace mlx::core { std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun); +/** Globally disable compilation. + * Setting the environment variable ``MLX_DISABLE_COMPILER`` can also + * be used to disable compilation. + */ +void disable_compiler(); + +/** Globally enable compilation. + * This will override the environment variable ``MLX_DISABLE_COMPILER``. + */ +void enable_compiler(); + /** Fuse equivalent arrays to avoid duplicate execution. */ void simplify(const std::vector& outputs); diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 11ac1efc6..30553e263 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -854,6 +854,24 @@ void init_transforms(py::module_& m) { function: A compiled function which has the same input arguments as ``fun`` and returns the the same output(s). )pbdoc"); + m.def( + "disable_compiler", + &disable_compiler, + R"pbdoc( + disable_compiler() -> None + + Globally disable compilation. Setting the environment variable + ``MLX_DISABLE_COMPILER`` can also be used to disable compilation. + )pbdoc"); + m.def( + "enable_compiler", + &enable_compiler, + R"pbdoc( + enable_compiler() -> None + + Globally enable compilation. This will override the environment + variable ``MLX_DISABLE_COMPILER`` if set. + )pbdoc"); // Register static Python object cleanup before the interpreter exits auto atexit = py::module_::import("atexit"); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index ab3acf115..f139d0a8e 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import io import unittest import mlx.core as mx @@ -146,6 +147,32 @@ class TestCompile(mlx_tests.MLXTestCase): out = cfun(mx.array(3)) self.assertEqual(out.item(), 4) + def test_enable_disable(self): + def fun(x): + y = x + 1 + z = x + 1 + return y + z + + def count_prims(outputs): + buf = io.StringIO() + mx.export_to_dot(buf, outputs) + buf.seek(0) + return len([l for l in buf.read().split() if "label" in l]) + + x = mx.array(1.0) + cfun = mx.compile(fun) + n_compiled = count_prims(cfun(x)) + + # Check disabled + mx.disable_compiler() + n_uncompiled = count_prims(cfun(x)) + self.assertTrue(n_compiled < n_uncompiled) + + # Check renabled + mx.enable_compiler() + n_enable_compiled = count_prims(cfun(x)) + self.assertEqual(n_compiled, n_enable_compiled) + if __name__ == "__main__": unittest.main() diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index b012fff36..a6252e701 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -96,3 +96,11 @@ TEST_CASE("test nested compile") { out = cfun({array(1), array(2)})[0]; CHECK_EQ(out.item(), 9); } + +TEST_CASE("test enable and disable compile") { + CHECK_THROWS(compile(nullptr)); + disable_compiler(); + compile(nullptr); + enable_compiler(); + CHECK_THROWS(compile(nullptr)); +}