mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 03:06:39 +08:00
fix segfault on python exit
This commit is contained in:
parent
0005cfe053
commit
966b7faef4
@ -449,10 +449,13 @@ auto py_vmap(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto py_compile(const py::function& fun) {
|
std::unordered_map<size_t, py::object>& tree_cache() {
|
||||||
// This map is used to Cache the tree structure of the outputs
|
// This map is used to Cache the tree structure of the outputs
|
||||||
static std::unordered_map<size_t, py::object> tree_cache;
|
static std::unordered_map<size_t, py::object> tree_cache_;
|
||||||
|
return tree_cache_;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto py_compile(const py::function& fun) {
|
||||||
return [fun](const py::args& args) {
|
return [fun](const py::args& args) {
|
||||||
// Inputs must be array or tree of arrays
|
// Inputs must be array or tree of arrays
|
||||||
auto inputs = tree_flatten(args, true);
|
auto inputs = tree_flatten(args, true);
|
||||||
@ -470,7 +473,7 @@ auto py_compile(const py::function& fun) {
|
|||||||
|
|
||||||
py_outputs =
|
py_outputs =
|
||||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||||
tree_cache.insert({fun_id, py_outputs});
|
tree_cache().insert({fun_id, py_outputs});
|
||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -478,7 +481,7 @@ auto py_compile(const py::function& fun) {
|
|||||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||||
|
|
||||||
// Put the outputs back in the container
|
// Put the outputs back in the container
|
||||||
py::object py_outputs = tree_cache.at(fun_id);
|
py::object py_outputs = tree_cache().at(fun_id);
|
||||||
return tree_unflatten_none(py_outputs, outputs);
|
return tree_unflatten_none(py_outputs, outputs);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -800,4 +803,8 @@ void init_transforms(py::module_& m) {
|
|||||||
function: A compiled function which has the same input arguments
|
function: A compiled function which has the same input arguments
|
||||||
as ``fun`` and returns the the same output(s).
|
as ``fun`` and returns the the same output(s).
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
|
// Register static Python object cleanup before the interpreter exits
|
||||||
|
auto atexit = py::module_::import("atexit");
|
||||||
|
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||||
}
|
}
|
||||||
|
@ -22,6 +22,39 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = compiled_fn(x, y)
|
out = compiled_fn(x, y)
|
||||||
self.assertEqual(out.item(), 2.0)
|
self.assertEqual(out.item(), 2.0)
|
||||||
|
|
||||||
|
def test_compile_grad(self):
|
||||||
|
def loss_fn(x):
|
||||||
|
return mx.exp(x).sum()
|
||||||
|
|
||||||
|
grad_fn = mx.grad(loss_fn)
|
||||||
|
|
||||||
|
x = mx.array([0.5, -0.5, 1.2])
|
||||||
|
dfdx = grad_fn(x)
|
||||||
|
compile_grad_fn = mx.compile(grad_fn)
|
||||||
|
c_dfdx = grad_fn(x)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Run it again without calling compile
|
||||||
|
c_dfdx = compile_grad_fn(x)
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Run it again with calling compile
|
||||||
|
c_dfdx = mx.compile(grad_fn)(x)
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
|
||||||
|
# Value and grad
|
||||||
|
def loss_fn(x):
|
||||||
|
return mx.exp(x).sum(), mx.sin(x)
|
||||||
|
|
||||||
|
val_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||||
|
(loss, val), dfdx = val_and_grad_fn(x)
|
||||||
|
(c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)
|
||||||
|
|
||||||
|
self.assertTrue(mx.allclose(c_dfdx, dfdx))
|
||||||
|
self.assertTrue(mx.allclose(c_loss, loss))
|
||||||
|
self.assertTrue(mx.allclose(c_val, val))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user