diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 55304c6f02..e6748f505f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -4,6 +4,7 @@ #include #include +#include "mlx/allocator.h" #include "mlx/primitives.h" #include "mlx/transforms.h" #include "mlx/transforms_impl.h" @@ -92,12 +93,12 @@ struct CompilerCache { cache_.erase(fun_id); } - void clear() { - cache_.clear(); - } - private: - CompilerCache() {} + CompilerCache() { + // Make sure the allocator is fully + // initialized before the compiler cache + allocator::allocator(); + } friend CompilerCache& compiler_cache(); std::unordered_map> cache_; }; @@ -400,10 +401,6 @@ void compile_erase(size_t fun_id) { detail::compiler_cache().erase(fun_id); } -void compile_clear() { - detail::compiler_cache().clear(); -} - } // namespace detail std::function(const std::vector&)> compile( diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index d0bd0de74d..dd5bdc8098 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -1,7 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include -#include #include #include #include diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 1f66089fbb..a1b3461ab4 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -23,9 +23,6 @@ std::function(const std::vector&)> compile( // Erase cached compile functions void compile_erase(size_t fun_id); -// Clear the compiler cache -void compile_clear(); - // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index d196feb553..6897f72e01 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -836,8 +836,5 @@ void init_transforms(py::module_& m) { // Register static Python object cleanup before the interpreter exits auto atexit = py::module_::import("atexit"); - atexit.attr("register")(py::cpp_function([]() { - detail::compile_clear(); - tree_cache().clear(); - })); + atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 30ac16dd1f..d21f2989d7 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -105,60 +105,98 @@ TEST_CASE("test enable and disable compile") { CHECK_THROWS(compile(nullptr)); } +auto add_scalars(const std::vector&) { + auto a = array(-1.0f); + auto b = array(-1.0f); + return std::vector{abs(a), abs(b)}; +}; + +auto max_scalars(const std::vector&) { + auto a = array({-1.0f, 2.0f}); + auto b = maximum(a, array(0.0f)); + auto c = maximum(-a, array(0.0f)); + auto d = b + c; + return std::vector{b, c, d}; +}; + TEST_CASE("test simplify scalars") { { - auto a = array(-1.0f); - auto b = array(-1.0f); - auto c = abs(a); - auto d = abs(b); - simplify({c, d}); + auto cfun = compile(add_scalars); + auto out = cfun({}); + auto c = out[0]; + auto d = out[1]; CHECK(c.inputs()[0].id() == d.inputs()[0].id()); } { auto a = array({-1.0f, 2.0f}); - auto b = maximum(a, array(0.0f)); - auto c = maximum(-a, array(0.0f)); - auto d = b + c; - simplify({d}); + auto out = compile(max_scalars)({a}); + auto b = out[0]; + auto c = out[1]; + auto d = out[2]; CHECK(b.inputs()[1].id() == c.inputs()[1].id()); } } -// TODO rework these tests for compile -/*TEST_CASE("test simplify") { +auto exp_two(const std::vector& inputs) { + auto a = inputs[0]; + return std::vector{exp(a) + exp(a)}; +}; + +TEST_CASE("test simplify") { auto a = array({1.0f, 2.0f}); - auto b = exp(a) + exp(a); - simplify(b); + auto b = compile(exp_two)({a})[0]; CHECK(b.inputs()[0].id() == b.inputs()[1].id()); } +auto add_diff(const std::vector& inputs) { + auto a = inputs[0]; + return std::vector{cos(a) + sin(a)}; +}; + TEST_CASE("test no simplify") { auto a = array({1.0f, 2.0f}); - auto b = cos(a) + sin(a); - simplify(b); + auto b = compile(add_diff)({a})[0]; CHECK(b.inputs()[0].id() != b.inputs()[1].id()); } +auto multi_one(const std::vector&) { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = c[0] + d[0]; + auto f = c[1] + d[1]; + return std::vector{e, f}; +} + +auto multi_two(const std::vector&) { + auto a = array(1.0); + auto b = array(1.0); + auto c = divmod(a, b); + return std::vector{c}; +} + +auto multi_three(const std::vector&) { + auto a = array(1.0); + auto b = array(2.0); + auto c = divmod(a, b); + auto d = divmod(a, b); + auto e = stack({c[0], c[1], d[0], d[1]}); + return std::vector{e}; +} + TEST_CASE("test simplify multi output") { { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = c[0] + d[0]; - auto f = c[1] + d[1]; - - simplify({e, f}); + auto out = compile(multi_one)({}); + auto e = out[0]; + auto f = out[1]; CHECK_EQ(e.inputs()[0].id(), e.inputs()[1].id()); CHECK_EQ(f.inputs()[0].id(), f.inputs()[1].id()); } { - auto a = array(1.0); - auto b = array(1.0); - auto c = divmod(a, b); - simplify(c); + auto c = compile(multi_two)({}); CHECK_EQ(c[0].inputs()[0].id(), c[0].inputs()[1].id()); CHECK_EQ(c[0].inputs()[0].id(), c[1].inputs()[0].id()); CHECK_EQ(c[1].inputs()[0].id(), c[1].inputs()[1].id()); @@ -167,14 +205,9 @@ TEST_CASE("test simplify multi output") { // Make sure the output order of multi-output primitives // is respected in simplification { - auto a = array(1.0); - auto b = array(2.0); - auto c = divmod(a, b); - auto d = divmod(a, b); - auto e = stack({c[0], c[1], d[0], d[1]}); - simplify(e); + auto e = compile(multi_three)({})[0]; CHECK(array_equal(e, array({0.0f, 1.0f, 0.0f, 1.0f})).item()); CHECK_EQ(e.inputs()[0].id(), e.inputs()[2].id()); CHECK_EQ(e.inputs()[1].id(), e.inputs()[3].id()); } -}*/ +}