mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 02:28:13 +08:00 
			
		
		
		
	fix segfault on python exit
This commit is contained in:
		@@ -449,10 +449,13 @@ auto py_vmap(
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
auto py_compile(const py::function& fun) {
 | 
			
		||||
std::unordered_map<size_t, py::object>& tree_cache() {
 | 
			
		||||
  // This map is used to Cache the tree structure of the outputs
 | 
			
		||||
  static std::unordered_map<size_t, py::object> tree_cache;
 | 
			
		||||
  static std::unordered_map<size_t, py::object> tree_cache_;
 | 
			
		||||
  return tree_cache_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
auto py_compile(const py::function& fun) {
 | 
			
		||||
  return [fun](const py::args& args) {
 | 
			
		||||
    // Inputs must be array or tree of arrays
 | 
			
		||||
    auto inputs = tree_flatten(args, true);
 | 
			
		||||
@@ -470,7 +473,7 @@ auto py_compile(const py::function& fun) {
 | 
			
		||||
 | 
			
		||||
      py_outputs =
 | 
			
		||||
          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
 | 
			
		||||
      tree_cache.insert({fun_id, py_outputs});
 | 
			
		||||
      tree_cache().insert({fun_id, py_outputs});
 | 
			
		||||
      return outputs;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@@ -478,7 +481,7 @@ auto py_compile(const py::function& fun) {
 | 
			
		||||
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);
 | 
			
		||||
 | 
			
		||||
    // Put the outputs back in the container
 | 
			
		||||
    py::object py_outputs = tree_cache.at(fun_id);
 | 
			
		||||
    py::object py_outputs = tree_cache().at(fun_id);
 | 
			
		||||
    return tree_unflatten_none(py_outputs, outputs);
 | 
			
		||||
  };
 | 
			
		||||
}
 | 
			
		||||
@@ -800,4 +803,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
            function: A compiled function which has the same input arguments
 | 
			
		||||
            as ``fun`` and returns the the same output(s).
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
  auto atexit = py::module_::import("atexit");
 | 
			
		||||
  atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -22,6 +22,39 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
			
		||||
        out = compiled_fn(x, y)
 | 
			
		||||
        self.assertEqual(out.item(), 2.0)
 | 
			
		||||
 | 
			
		||||
    def test_compile_grad(self):
 | 
			
		||||
        def loss_fn(x):
 | 
			
		||||
            return mx.exp(x).sum()
 | 
			
		||||
 | 
			
		||||
        grad_fn = mx.grad(loss_fn)
 | 
			
		||||
 | 
			
		||||
        x = mx.array([0.5, -0.5, 1.2])
 | 
			
		||||
        dfdx = grad_fn(x)
 | 
			
		||||
        compile_grad_fn = mx.compile(grad_fn)
 | 
			
		||||
        c_dfdx = grad_fn(x)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(mx.allclose(c_dfdx, dfdx))
 | 
			
		||||
 | 
			
		||||
        # Run it again without calling compile
 | 
			
		||||
        c_dfdx = compile_grad_fn(x)
 | 
			
		||||
        self.assertTrue(mx.allclose(c_dfdx, dfdx))
 | 
			
		||||
 | 
			
		||||
        # Run it again with calling compile
 | 
			
		||||
        c_dfdx = mx.compile(grad_fn)(x)
 | 
			
		||||
        self.assertTrue(mx.allclose(c_dfdx, dfdx))
 | 
			
		||||
 | 
			
		||||
        # Value and grad
 | 
			
		||||
        def loss_fn(x):
 | 
			
		||||
            return mx.exp(x).sum(), mx.sin(x)
 | 
			
		||||
 | 
			
		||||
        val_and_grad_fn = mx.value_and_grad(loss_fn)
 | 
			
		||||
        (loss, val), dfdx = val_and_grad_fn(x)
 | 
			
		||||
        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(mx.allclose(c_dfdx, dfdx))
 | 
			
		||||
        self.assertTrue(mx.allclose(c_loss, loss))
 | 
			
		||||
        self.assertTrue(mx.allclose(c_val, val))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user