mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
bug fix with move function and compile at exit
This commit is contained in:
parent
ecfb72157e
commit
df1f6c221b
@ -1,4 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
@ -73,6 +74,10 @@ struct CompilerCache {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
CompilerCache() {}
|
||||
friend CompilerCache& compiler_cache();
|
||||
@ -363,6 +368,10 @@ void compile_erase(size_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
void compile_clear() {
|
||||
detail::compiler_cache().clear();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
namespace mlx::core::detail {
|
||||
|
||||
@ -23,6 +23,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
// Erase cached compile functions
|
||||
void compile_erase(size_t fun_id);
|
||||
|
||||
// Clear the compiler cache
|
||||
void compile_clear();
|
||||
|
||||
// Create an InTracing object during tracing operations to signify to the rest
|
||||
// of the codebase that we are during tracing so evals should not throw away
|
||||
// the graph.
|
||||
|
@ -1,10 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream> // TODO
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
@ -458,12 +457,24 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
size_t fun_id;
|
||||
|
||||
PyCompiledFun(const py::function& fun)
|
||||
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||
PyCompiledFun(PyCompiledFun&& other) {
|
||||
fun = other.fun;
|
||||
other.fun_id = 0;
|
||||
fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
// TODO, awni, I think this cast is ok??
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
|
||||
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||
|
||||
@ -472,7 +483,7 @@ struct PyCompiledFun {
|
||||
|
||||
py_outputs =
|
||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||
tree_cache().insert({fun_id, py_outputs});
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
return outputs;
|
||||
};
|
||||
|
||||
@ -520,7 +531,6 @@ struct PyCompiledFun {
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
detail::compile_erase(fun_id);
|
||||
}
|
||||
};
|
||||
@ -847,5 +857,8 @@ void init_transforms(py::module_& m) {
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||
atexit.attr("register")(py::cpp_function([]() {
|
||||
detail::compile_clear();
|
||||
tree_cache().clear();
|
||||
}));
|
||||
}
|
||||
|
@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = compiled(mx.array(1))
|
||||
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
|
||||
|
||||
def test_function_creates_array(self):
|
||||
def fun(x):
|
||||
return x + mx.array(1)
|
||||
|
||||
cfun = mx.compile(fun)
|
||||
out = cfun(mx.array(3))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
# And again
|
||||
out = cfun(mx.array(3))
|
||||
self.assertEqual(out.item(), 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user