mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
bug fix with move function and compile at exit
This commit is contained in:
parent
ecfb72157e
commit
df1f6c221b
@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
@ -73,6 +74,10 @@ struct CompilerCache {
|
|||||||
cache_.erase(fun_id);
|
cache_.erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
cache_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompilerCache() {}
|
CompilerCache() {}
|
||||||
friend CompilerCache& compiler_cache();
|
friend CompilerCache& compiler_cache();
|
||||||
@ -363,6 +368,10 @@ void compile_erase(size_t fun_id) {
|
|||||||
detail::compiler_cache().erase(fun_id);
|
detail::compiler_cache().erase(fun_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void compile_clear() {
|
||||||
|
detail::compiler_cache().clear();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
namespace mlx::core::detail {
|
namespace mlx::core::detail {
|
||||||
|
|
||||||
@ -23,6 +23,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
// Erase cached compile functions
|
// Erase cached compile functions
|
||||||
void compile_erase(size_t fun_id);
|
void compile_erase(size_t fun_id);
|
||||||
|
|
||||||
|
// Clear the compiler cache
|
||||||
|
void compile_clear();
|
||||||
|
|
||||||
// Create an InTracing object during tracing operations to signify to the rest
|
// Create an InTracing object during tracing operations to signify to the rest
|
||||||
// of the codebase that we are during tracing so evals should not throw away
|
// of the codebase that we are during tracing so evals should not throw away
|
||||||
// the graph.
|
// the graph.
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <pybind11/functional.h>
|
#include <pybind11/functional.h>
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream> // TODO
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
@ -458,12 +457,24 @@ std::unordered_map<size_t, py::object>& tree_cache() {
|
|||||||
|
|
||||||
struct PyCompiledFun {
|
struct PyCompiledFun {
|
||||||
py::function fun;
|
py::function fun;
|
||||||
|
size_t fun_id;
|
||||||
|
|
||||||
|
PyCompiledFun(const py::function& fun)
|
||||||
|
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||||
|
|
||||||
|
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||||
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||||
|
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||||
|
PyCompiledFun(PyCompiledFun&& other) {
|
||||||
|
fun = other.fun;
|
||||||
|
other.fun_id = 0;
|
||||||
|
fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||||
|
};
|
||||||
|
|
||||||
py::object operator()(const py::args& args) {
|
py::object operator()(const py::args& args) {
|
||||||
// TODO, awni, I think this cast is ok??
|
// TODO, awni, I think this cast is ok??
|
||||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
|
||||||
|
|
||||||
auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
|
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||||
// Call the python function
|
// Call the python function
|
||||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||||
|
|
||||||
@ -472,7 +483,7 @@ struct PyCompiledFun {
|
|||||||
|
|
||||||
py_outputs =
|
py_outputs =
|
||||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||||
tree_cache().insert({fun_id, py_outputs});
|
tree_cache().insert({this->fun_id, py_outputs});
|
||||||
return outputs;
|
return outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -520,7 +531,6 @@ struct PyCompiledFun {
|
|||||||
};
|
};
|
||||||
|
|
||||||
~PyCompiledFun() {
|
~PyCompiledFun() {
|
||||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
|
||||||
detail::compile_erase(fun_id);
|
detail::compile_erase(fun_id);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -847,5 +857,8 @@ void init_transforms(py::module_& m) {
|
|||||||
|
|
||||||
// Register static Python object cleanup before the interpreter exits
|
// Register static Python object cleanup before the interpreter exits
|
||||||
auto atexit = py::module_::import("atexit");
|
auto atexit = py::module_::import("atexit");
|
||||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
atexit.attr("register")(py::cpp_function([]() {
|
||||||
|
detail::compile_clear();
|
||||||
|
tree_cache().clear();
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
out = compiled(mx.array(1))
|
out = compiled(mx.array(1))
|
||||||
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
|
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
|
||||||
|
|
||||||
|
def test_function_creates_array(self):
|
||||||
|
def fun(x):
|
||||||
|
return x + mx.array(1)
|
||||||
|
|
||||||
|
cfun = mx.compile(fun)
|
||||||
|
out = cfun(mx.array(3))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
# And again
|
||||||
|
out = cfun(mx.array(3))
|
||||||
|
self.assertEqual(out.item(), 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user