mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	bug fix with move function and compile at exit
This commit is contained in:
		@@ -1,4 +1,5 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <map>
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <unordered_set>
 | 
			
		||||
@@ -73,6 +74,10 @@ struct CompilerCache {
 | 
			
		||||
    cache_.erase(fun_id);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void clear() {
 | 
			
		||||
    cache_.clear();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  CompilerCache() {}
 | 
			
		||||
  friend CompilerCache& compiler_cache();
 | 
			
		||||
@@ -363,6 +368,10 @@ void compile_erase(size_t fun_id) {
 | 
			
		||||
  detail::compiler_cache().erase(fun_id);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void compile_clear() {
 | 
			
		||||
  detail::compiler_cache().clear();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
namespace mlx::core::detail {
 | 
			
		||||
 | 
			
		||||
@@ -23,6 +23,9 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
 | 
			
		||||
// Erase cached compile functions
 | 
			
		||||
void compile_erase(size_t fun_id);
 | 
			
		||||
 | 
			
		||||
// Clear the compiler cache
 | 
			
		||||
void compile_clear();
 | 
			
		||||
 | 
			
		||||
// Create an InTracing object during tracing operations to signify to the rest
 | 
			
		||||
// of the codebase that we are during tracing so evals should not throw away
 | 
			
		||||
// the graph.
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,9 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
// Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
#include <pybind11/functional.h>
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <pybind11/stl.h>
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <iostream> // TODO
 | 
			
		||||
#include <numeric>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
@@ -458,12 +457,24 @@ std::unordered_map<size_t, py::object>& tree_cache() {
 | 
			
		||||
 | 
			
		||||
struct PyCompiledFun {
 | 
			
		||||
  py::function fun;
 | 
			
		||||
  size_t fun_id;
 | 
			
		||||
 | 
			
		||||
  PyCompiledFun(const py::function& fun)
 | 
			
		||||
      : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
 | 
			
		||||
 | 
			
		||||
  PyCompiledFun(const PyCompiledFun&) = delete;
 | 
			
		||||
  PyCompiledFun& operator=(const PyCompiledFun&) = delete;
 | 
			
		||||
  PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
 | 
			
		||||
  PyCompiledFun(PyCompiledFun&& other) {
 | 
			
		||||
    fun = other.fun;
 | 
			
		||||
    other.fun_id = 0;
 | 
			
		||||
    fun_id = reinterpret_cast<size_t>(fun.ptr());
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  py::object operator()(const py::args& args) {
 | 
			
		||||
    // TODO, awni, I think this cast is ok??
 | 
			
		||||
    size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
 | 
			
		||||
 | 
			
		||||
    auto compile_fun = [fun_id, this, &args](const std::vector<array>& a) {
 | 
			
		||||
    auto compile_fun = [this, &args](const std::vector<array>& a) {
 | 
			
		||||
      // Call the python function
 | 
			
		||||
      py::object py_outputs = this->fun(*tree_unflatten(args, a));
 | 
			
		||||
 | 
			
		||||
@@ -472,7 +483,7 @@ struct PyCompiledFun {
 | 
			
		||||
 | 
			
		||||
      py_outputs =
 | 
			
		||||
          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
 | 
			
		||||
      tree_cache().insert({fun_id, py_outputs});
 | 
			
		||||
      tree_cache().insert({this->fun_id, py_outputs});
 | 
			
		||||
      return outputs;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
@@ -520,7 +531,6 @@ struct PyCompiledFun {
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  ~PyCompiledFun() {
 | 
			
		||||
    size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
 | 
			
		||||
    detail::compile_erase(fun_id);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
@@ -847,5 +857,8 @@ void init_transforms(py::module_& m) {
 | 
			
		||||
 | 
			
		||||
  // Register static Python object cleanup before the interpreter exits
 | 
			
		||||
  auto atexit = py::module_::import("atexit");
 | 
			
		||||
  atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
 | 
			
		||||
  atexit.attr("register")(py::cpp_function([]() {
 | 
			
		||||
    detail::compile_clear();
 | 
			
		||||
    tree_cache().clear();
 | 
			
		||||
  }));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -134,6 +134,18 @@ class TestCompile(mlx_tests.MLXTestCase):
 | 
			
		||||
        out = compiled(mx.array(1))
 | 
			
		||||
        self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
 | 
			
		||||
 | 
			
		||||
    def test_function_creates_array(self):
 | 
			
		||||
        def fun(x):
 | 
			
		||||
            return x + mx.array(1)
 | 
			
		||||
 | 
			
		||||
        cfun = mx.compile(fun)
 | 
			
		||||
        out = cfun(mx.array(3))
 | 
			
		||||
        self.assertEqual(out.item(), 4)
 | 
			
		||||
 | 
			
		||||
        # And again
 | 
			
		||||
        out = cfun(mx.array(3))
 | 
			
		||||
        self.assertEqual(out.item(), 4)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user