bug fix with move function and compile at exit

This commit is contained in:
Awni Hannun
2024-01-16 11:18:36 -08:00
parent ecfb72157e
commit df1f6c221b
4 changed files with 46 additions and 9 deletions

View File

@@ -1,10 +1,9 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <fstream>
#include <iostream> // TODO
#include <numeric>
#include <sstream>
@@ -458,12 +457,24 @@ std::unordered_map<size_t, py::object>& tree_cache() {
struct PyCompiledFun {
py::function fun;
size_t fun_id;
PyCompiledFun(const py::function& fun)
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other) {
fun = other.fun;
other.fun_id = 0;
fun_id = reinterpret_cast<size_t>(fun.ptr());
};
py::object operator()(const py::args& args) {
// TODO, awni, I think this cast is ok??
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a));
@@ -472,7 +483,7 @@ struct PyCompiledFun {
py_outputs =
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
tree_cache().insert({fun_id, py_outputs});
tree_cache().insert({this->fun_id, py_outputs});
return outputs;
};
@@ -520,7 +531,6 @@ struct PyCompiledFun {
};
~PyCompiledFun() {
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
detail::compile_erase(fun_id);
}
};
@@ -847,5 +857,8 @@ void init_transforms(py::module_& m) {
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
atexit.attr("register")(py::cpp_function([]() {
detail::compile_clear();
tree_cache().clear();
}));
}

View File

@@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase):
out = compiled(mx.array(1))
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
def test_function_creates_array(self):
def fun(x):
return x + mx.array(1)
cfun = mx.compile(fun)
out = cfun(mx.array(3))
self.assertEqual(out.item(), 4)
# And again
out = cfun(mx.array(3))
self.assertEqual(out.item(), 4)
if __name__ == "__main__":
unittest.main()