diff --git a/mlx/compile.cpp b/mlx/compile.cpp index cfdd775d4..ab3be9b07 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -70,6 +70,10 @@ struct CompilerCache { return entries.back(); }; + void erase(size_t fun_id) { + cache_.erase(fun_id); + } + private: CompilerCache() {} friend CompilerCache& compiler_cache(); @@ -357,6 +361,11 @@ std::function(const std::vector&)> compile( return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs); }; } + +void compile_erase(size_t fun_id) { + detail::compiler_cache().erase(fun_id); +} + } // namespace detail std::function(const std::vector&)> compile( diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 5b9c7ada5..0616415c2 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -20,6 +20,9 @@ std::function(const std::vector&)> compile( const std::function(const std::vector&)>& fun, size_t fun_id); +// Erase cached compile functions +void compile_erase(size_t fun_id); + // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index cf9c09266..07a621b5f 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -4,6 +4,7 @@ #include #include #include +#include // TODO #include #include @@ -455,14 +456,16 @@ std::unordered_map& tree_cache() { return tree_cache_; } -auto py_compile(const py::function& fun) { - return [fun](const py::args& args) { +struct PyCompiledFun { + py::function fun; + + py::object operator()(const py::args& args) { // TODO, awni, I think this cast is ok?? size_t fun_id = reinterpret_cast(fun.ptr()); - auto compile_fun = [fun_id, &fun, &args](const std::vector& a) { + auto compile_fun = [fun_id, this, &args](const std::vector& a) { // Call the python function - py::object py_outputs = fun(*tree_unflatten(args, a)); + py::object py_outputs = this->fun(*tree_unflatten(args, a)); // Flatten the outputs auto outputs = tree_flatten(py_outputs, true); @@ -477,12 +480,20 @@ auto py_compile(const py::function& fun) { auto inputs = tree_flatten(args, true); // Get globally enclosed arrays so we don't compile through them + // c.f. https://github.com/python/cpython/blob/main/Lib/inspect.py#L1638 if (py::hasattr(fun, "__globals__")) { - auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false); - std::move( - std::begin(global_inputs), - std::end(global_inputs), - std::back_inserter(inputs)); + py::dict globals = py::getattr(fun, "__globals__"); + auto co_names = py::getattr(py::getattr(fun, "__code__"), "co_names"); + for (auto& n : co_names) { + if (py::cast(globals.attr("__contains__")(n))) { + auto global_inputs = + tree_flatten(globals.attr("__getitem__")(n), false); + std::move( + std::begin(global_inputs), + std::end(global_inputs), + std::back_inserter(inputs)); + } + } } // Get locally enclosed arrays so we don't compile through them @@ -507,7 +518,12 @@ auto py_compile(const py::function& fun) { py::object py_outputs = tree_cache().at(fun_id); return tree_unflatten_none(py_outputs, outputs); }; -} + + ~PyCompiledFun() { + size_t fun_id = reinterpret_cast(fun.ptr()); + detail::compile_erase(fun_id); + } +}; void init_transforms(py::module_& m) { py::options options; @@ -810,7 +826,9 @@ void init_transforms(py::module_& m) { "file"_a); m.def( "compile", - [](const py::function& fun) { return py::cpp_function(py_compile(fun)); }, + [](const py::function& fun) { + return py::cpp_function(PyCompiledFun{fun}); + }, "fun"_a, R"pbdoc( compile(fun: function) -> function