diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 17aa6a523..124126640 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -1,4 +1,5 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. + #include #include #include @@ -73,6 +74,10 @@ struct CompilerCache { cache_.erase(fun_id); } + void clear() { + cache_.clear(); + } + private: CompilerCache() {} friend CompilerCache& compiler_cache(); @@ -363,6 +368,10 @@ void compile_erase(size_t fun_id) { detail::compiler_cache().erase(fun_id); } +void compile_clear() { + detail::compiler_cache().clear(); +} + } // namespace detail std::function(const std::vector&)> compile( diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 0616415c2..1f66089fb 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. namespace mlx::core::detail { @@ -23,6 +23,9 @@ std::function(const std::vector&)> compile( // Erase cached compile functions void compile_erase(size_t fun_id); +// Clear the compiler cache +void compile_clear(); + // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 07a621b5f..11ac1efc6 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -1,10 +1,9 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include #include #include #include #include -#include // TODO #include #include @@ -458,12 +457,24 @@ std::unordered_map& tree_cache() { struct PyCompiledFun { py::function fun; + size_t fun_id; + + PyCompiledFun(const py::function& fun) + : fun(fun), fun_id(reinterpret_cast(fun.ptr())) {} + + PyCompiledFun(const PyCompiledFun&) = delete; + PyCompiledFun& operator=(const PyCompiledFun&) = delete; + PyCompiledFun& operator=(PyCompiledFun&& other) = delete; + PyCompiledFun(PyCompiledFun&& other) { + fun = other.fun; + other.fun_id = 0; + fun_id = reinterpret_cast(fun.ptr()); + }; py::object operator()(const py::args& args) { // TODO, awni, I think this cast is ok?? - size_t fun_id = reinterpret_cast(fun.ptr()); - auto compile_fun = [fun_id, this, &args](const std::vector& a) { + auto compile_fun = [this, &args](const std::vector& a) { // Call the python function py::object py_outputs = this->fun(*tree_unflatten(args, a)); @@ -472,7 +483,7 @@ struct PyCompiledFun { py_outputs = tree_map(py_outputs, [](const py::handle& x) { return py::none(); }); - tree_cache().insert({fun_id, py_outputs}); + tree_cache().insert({this->fun_id, py_outputs}); return outputs; }; @@ -520,7 +531,6 @@ struct PyCompiledFun { }; ~PyCompiledFun() { - size_t fun_id = reinterpret_cast(fun.ptr()); detail::compile_erase(fun_id); } }; @@ -847,5 +857,8 @@ void init_transforms(py::module_& m) { // Register static Python object cleanup before the interpreter exits auto atexit = py::module_::import("atexit"); - atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); + atexit.attr("register")(py::cpp_function([]() { + detail::compile_clear(); + tree_cache().clear(); + })); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index c9961ea4d..ab3acf115 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase): out = compiled(mx.array(1)) self.assertTrue(mx.array_equal(out, mx.array([-1, -2]))) + def test_function_creates_array(self): + def fun(x): + return x + mx.array(1) + + cfun = mx.compile(fun) + out = cfun(mx.array(3)) + self.assertEqual(out.item(), 4) + + # And again + out = cfun(mx.array(3)) + self.assertEqual(out.item(), 4) + if __name__ == "__main__": unittest.main()