From 966b7faef4bdceef376e3da2e0576ce5c6ffaf63 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 15 Jan 2024 08:33:28 -0800 Subject: [PATCH] fix segfault on python exit --- python/src/transforms.cpp | 15 +++++++++++---- python/tests/test_compile.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 91d9f88030..bef7f8993f 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -449,10 +449,13 @@ auto py_vmap( }; } -auto py_compile(const py::function& fun) { +std::unordered_map& tree_cache() { // This map is used to Cache the tree structure of the outputs - static std::unordered_map tree_cache; + static std::unordered_map tree_cache_; + return tree_cache_; +} +auto py_compile(const py::function& fun) { return [fun](const py::args& args) { // Inputs must be array or tree of arrays auto inputs = tree_flatten(args, true); @@ -470,7 +473,7 @@ auto py_compile(const py::function& fun) { py_outputs = tree_map(py_outputs, [](const py::handle& x) { return py::none(); }); - tree_cache.insert({fun_id, py_outputs}); + tree_cache().insert({fun_id, py_outputs}); return outputs; }; @@ -478,7 +481,7 @@ auto py_compile(const py::function& fun) { auto outputs = detail::compile(compile_fun, fun_id)(inputs); // Put the outputs back in the container - py::object py_outputs = tree_cache.at(fun_id); + py::object py_outputs = tree_cache().at(fun_id); return tree_unflatten_none(py_outputs, outputs); }; } @@ -800,4 +803,8 @@ void init_transforms(py::module_& m) { function: A compiled function which has the same input arguments as ``fun`` and returns the the same output(s). )pbdoc"); + + // Register static Python object cleanup before the interpreter exits + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); })); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index f9b36414bc..a8bbb834b8 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -22,6 +22,39 @@ class TestCompile(mlx_tests.MLXTestCase): out = compiled_fn(x, y) self.assertEqual(out.item(), 2.0) + def test_compile_grad(self): + def loss_fn(x): + return mx.exp(x).sum() + + grad_fn = mx.grad(loss_fn) + + x = mx.array([0.5, -0.5, 1.2]) + dfdx = grad_fn(x) + compile_grad_fn = mx.compile(grad_fn) + c_dfdx = grad_fn(x) + + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Run it again without calling compile + c_dfdx = compile_grad_fn(x) + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Run it again with calling compile + c_dfdx = mx.compile(grad_fn)(x) + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + + # Value and grad + def loss_fn(x): + return mx.exp(x).sum(), mx.sin(x) + + val_and_grad_fn = mx.value_and_grad(loss_fn) + (loss, val), dfdx = val_and_grad_fn(x) + (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x) + + self.assertTrue(mx.allclose(c_dfdx, dfdx)) + self.assertTrue(mx.allclose(c_loss, loss)) + self.assertTrue(mx.allclose(c_val, val)) + if __name__ == "__main__": unittest.main()