mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Compile front-end (#476)
* fix tests for linux * make a move on compile * basic compile scaffold works * compile binding * clean * fix * fix grad, more tests * basic python tests * fix segfault on python exit * compile works with python closures * fix test * fix python globals bug, and erase * simplify * more cpp tests * bug fix with move function and compile at exit * simplify inputs also * enable and disable compiler * remove simplify * simplify tests use compile now * fix multi-output with compile * clear output tree from cache when function goes out of scope * ../python/src/transforms.cpp * remove closure capture * comments
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
@@ -163,6 +162,19 @@ py::object tree_unflatten(
|
||||
});
|
||||
}
|
||||
|
||||
py::object tree_unflatten_none(
|
||||
py::object tree,
|
||||
const std::vector<array>& values,
|
||||
int index = 0) {
|
||||
return tree_map(tree, [&](py::handle obj) {
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
return py::cast(values[index++]);
|
||||
} else {
|
||||
return py::cast<py::object>(obj);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto validate_argnums_argnames(
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
@@ -437,6 +449,58 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
std::unordered_map<size_t, py::object>& tree_cache() {
|
||||
// This map is used to Cache the tree structure of the outputs
|
||||
static std::unordered_map<size_t, py::object> tree_cache_;
|
||||
return tree_cache_;
|
||||
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
size_t fun_id;
|
||||
|
||||
PyCompiledFun(const py::function& fun)
|
||||
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
auto outputs = tree_flatten(py_outputs, true);
|
||||
|
||||
py_outputs =
|
||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_none(py_outputs, outputs);
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
}
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
@@ -679,45 +743,6 @@ void init_transforms(py::module_& m) {
|
||||
Returns:
|
||||
function: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"simplify",
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
simplify(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
simplify(*args) -> None
|
||||
|
||||
Simplify the graph that computes the arrays.
|
||||
|
||||
Run a few fast graph simplification operations to reuse computation and
|
||||
reduce memory consumption. This function is meant to be run every time
|
||||
so its overhead should be small, approximately 1ms for a graph with a
|
||||
few thousand nodes.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def foo(x):
|
||||
y = x @ x
|
||||
z = x @ x
|
||||
return y + z
|
||||
|
||||
x = mx.ones((10, 10))
|
||||
y = foo(x)
|
||||
z = foo(x)
|
||||
|
||||
# Computes the matmul twice
|
||||
mx.eval(y)
|
||||
|
||||
# Computes the matmul once
|
||||
mx.simplify(z)
|
||||
mx.eval(z)
|
||||
|
||||
Args:
|
||||
args: Any number of arrays and/or trees of arrays to be simplified.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](py::object file, const py::args& args) {
|
||||
@@ -736,4 +761,46 @@ void init_transforms(py::module_& m) {
|
||||
}
|
||||
},
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) {
|
||||
return py::cpp_function(PyCompiledFun{fun});
|
||||
},
|
||||
"fun"_a,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
R"pbdoc(
|
||||
disable_compile() -> None
|
||||
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
R"pbdoc(
|
||||
enable_compiler() -> None
|
||||
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
)pbdoc");
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||
}
|
||||
|
Reference in New Issue
Block a user