bug fix with move function and compile at exit

This commit is contained in:
Awni Hannun 2024-01-16 11:18:36 -08:00
parent ecfb72157e
commit df1f6c221b
4 changed files with 46 additions and 9 deletions

View File

@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <map>
#include <unordered_map>
#include <unordered_set>
@ -73,6 +74,10 @@ struct CompilerCache {
cache_.erase(fun_id);
}
void clear() {
cache_.clear();
}
private:
CompilerCache() {}
friend CompilerCache& compiler_cache();
@ -363,6 +368,10 @@ void compile_erase(size_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
void compile_clear() {
detail::compiler_cache().clear();
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(

View File

@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
namespace mlx::core::detail {
@ -23,6 +23,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Erase cached compile functions
void compile_erase(size_t fun_id);
// Clear the compiler cache
void compile_clear();
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are during tracing so evals should not throw away
// the graph.

View File

@ -1,10 +1,9 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <fstream>
#include <iostream> // TODO
#include <numeric>
#include <sstream>
@ -458,12 +457,24 @@ std::unordered_map<size_t, py::object>& tree_cache() {
struct PyCompiledFun {
py::function fun;
size_t fun_id;
PyCompiledFun(const py::function& fun)
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
PyCompiledFun(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
PyCompiledFun(PyCompiledFun&& other) {
fun = other.fun;
other.fun_id = 0;
fun_id = reinterpret_cast<size_t>(fun.ptr());
};
py::object operator()(const py::args& args) {
// TODO, awni, I think this cast is ok??
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
auto compile_fun = [this, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = this->fun(*tree_unflatten(args, a));
@ -472,7 +483,7 @@ struct PyCompiledFun {
py_outputs =
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
tree_cache().insert({fun_id, py_outputs});
tree_cache().insert({this->fun_id, py_outputs});
return outputs;
};
@ -520,7 +531,6 @@ struct PyCompiledFun {
};
~PyCompiledFun() {
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
detail::compile_erase(fun_id);
}
};
@ -847,5 +857,8 @@ void init_transforms(py::module_& m) {
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
atexit.attr("register")(py::cpp_function([]() {
detail::compile_clear();
tree_cache().clear();
}));
}

View File

@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase):
out = compiled(mx.array(1))
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
def test_function_creates_array(self):
def fun(x):
return x + mx.array(1)
cfun = mx.compile(fun)
out = cfun(mx.array(3))
self.assertEqual(out.item(), 4)
# And again
out = cfun(mx.array(3))
self.assertEqual(out.item(), 4)
if __name__ == "__main__":
unittest.main()