Compile front-end (#476)

* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
This commit is contained in:
Awni Hannun
2024-01-26 13:45:30 -08:00
committed by GitHub
parent 874b739f3c
commit 8fa6b322b9
13 changed files with 1029 additions and 297 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -163,6 +162,19 @@ py::object tree_unflatten(
});
}
// Rebuild a tree by substituting arrays for the `None` leaves of `tree`.
//
// Every leaf of `tree` that is a Python `None` is replaced by the next
// element of `values`, starting at position `index` and advancing by one per
// replaced leaf. All other leaves are passed through as-is.
//
// No bounds check is performed: `values` must hold at least as many elements
// (from `index` onward) as there are `None` leaves in `tree`.
py::object tree_unflatten_none(
    py::object tree,
    const std::vector<array>& values,
    int index = 0) {
  auto fill_placeholder = [&values, &index](py::handle leaf) {
    // Guard clause: anything that is not a placeholder is kept unchanged.
    if (!py::isinstance<py::none>(leaf)) {
      return py::cast<py::object>(leaf);
    }
    // Placeholder: consume the next array.
    return py::cast(values[index++]);
  };
  return tree_map(tree, fill_placeholder);
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
@@ -437,6 +449,58 @@ auto py_vmap(
};
}
std::unordered_map<size_t, py::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, py::object> tree_cache_;
return tree_cache_;
}
// Callable wrapper around a Python function that routes calls through
// `detail::compile`.
//
// The wrapper owns two pieces of per-function state, both keyed by `fun_id`:
// the entry in `tree_cache()` holding the output tree structure, and whatever
// `detail::compile` associates with the id. Both are released in the
// destructor. The type is move-constructible only, so exactly one live object
// is responsible for a given id.
struct PyCompiledFun {
  // The wrapped Python callable.
  py::function fun;
  // Cache key: the address of the wrapped Python function object.
  size_t fun_id;

  PyCompiledFun(const py::function& fun)
      : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}

  PyCompiledFun(const PyCompiledFun&) = delete;
  PyCompiledFun& operator=(const PyCompiledFun&) = delete;
  PyCompiledFun& operator=(PyCompiledFun&& other) = delete;

  // Moving only transfers a Python handle and recomputes the id from it, so
  // it cannot throw. `fun` is declared before `fun_id`, hence it is already
  // initialized when `fun.ptr()` is read here.
  PyCompiledFun(PyCompiledFun&& other) noexcept
      : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
    // Reset the source id so that its destructor does not erase the cache
    // entries that now belong to this object.
    other.fun_id = 0;
  }

  // Call the wrapped function on `args` (arrays or trees of arrays) and
  // return its outputs in the same tree structure the Python function
  // produced.
  py::object operator()(const py::args& args) {
    // Adapter handed to detail::compile: takes flat arrays, returns flat
    // arrays. Captures are by pointer/reference, which is fine because the
    // lambda is only used within this call.
    auto compile_fun = [this, &args](const std::vector<array>& a) {
      // Call the python function with the arrays placed back into the
      // structure of `args`
      py::object py_outputs = this->fun(*tree_unflatten(args, a));

      // Flatten the outputs
      auto outputs = tree_flatten(py_outputs, true);

      // Keep only the shape of the output tree: every leaf becomes None so
      // the cached structure does not hold on to the arrays themselves.
      py_outputs =
          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
      // NOTE(review): insert() keeps an already cached tree for this id. If
      // this lambda can run more than once per id with a different output
      // structure, the first structure wins -- confirm that is intended.
      tree_cache().insert({this->fun_id, py_outputs});
      return outputs;
    };

    // Inputs must be array or tree of arrays
    auto inputs = tree_flatten(args, true);

    // Compile and call
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);

    // Put the outputs back in the container
    py::object py_outputs = tree_cache().at(fun_id);
    return tree_unflatten_none(py_outputs, outputs);
  }

  // Drop the cached output tree and the compile state registered under this
  // object's id.
  ~PyCompiledFun() {
    tree_cache().erase(fun_id);
    detail::compile_erase(fun_id);
  }
};
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
@@ -679,45 +743,6 @@ void init_transforms(py::module_& m) {
Returns:
function: The vectorized function.
)pbdoc");
m.def(
"simplify",
[](const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
simplify(arrays);
},
R"pbdoc(
simplify(*args) -> None
Simplify the graph that computes the arrays.
Run a few fast graph simplification operations to reuse computation and
reduce memory consumption. This function is meant to be run every time
so its overhead should be small, approximately 1ms for a graph with a
few thousand nodes.
.. code-block:: python
import mlx.core as mx
def foo(x):
y = x @ x
z = x @ x
return y + z
x = mx.ones((10, 10))
y = foo(x)
z = foo(x)
# Computes the matmul twice
mx.eval(y)
# Computes the matmul once
mx.simplify(z)
mx.eval(z)
Args:
args: Any number of arrays and/or trees of arrays to be simplified.
)pbdoc");
m.def(
"export_to_dot",
[](py::object file, const py::args& args) {
@@ -736,4 +761,46 @@ void init_transforms(py::module_& m) {
}
},
"file"_a);
// Bind `compile`: wraps `fun` in a PyCompiledFun and returns it to Python as
// a callable.
m.def(
"compile",
[](const py::function& fun) {
return py::cpp_function(PyCompiledFun{fun});
},
"fun"_a,
R"pbdoc(
compile(fun: function) -> function
Returns a compiled function which produces the same output as ``fun``.
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
Returns:
function: A compiled function which has the same input arguments
as ``fun`` and returns the same output(s).
)pbdoc");
m.def(
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compile() -> None
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
// Bind `enable_compile`. The signature line in the docstring must match the
// registered name because automatic function signatures are disabled.
m.def(
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compile() -> None
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
// Register static Python object cleanup before the interpreter exits
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
}