compile binding

This commit is contained in:
Awni Hannun
2024-01-14 14:26:53 -08:00
parent 21062680d5
commit cd1e5b25cc
5 changed files with 285 additions and 40 deletions

View File

@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <iostream> // TODO
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
@@ -437,6 +438,34 @@ auto py_vmap(
};
}
auto py_compile(const py::function& fun) {
return [fun](const py::args& args) {
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
// py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
py::object py_outputs;
auto compile_fun =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
// Call the python function
py_outputs = fun(*tree_unflatten(args, a));
// Flatten the outputs
return tree_flatten(py_outputs, true);
};
// Compile and call
// TODO, awni, I think this cast is ok??
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
// Put the outputs back in the container
return tree_unflatten(py_outputs, outputs);
};
}
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
@@ -736,4 +765,22 @@ void init_transforms(py::module_& m) {
}
},
"file"_a);
m.def(
"compile",
[](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
"fun"_a,
R"pbdoc(
compile(fun: function) -> function
Returns a compiled function which produces the same output as ``fun``.
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
Returns:
function: A compiled function which has the same input arguments
as ``fun`` and returns the the same output(s).
)pbdoc");
}

View File

@@ -0,0 +1,22 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestCompile(mlx_tests.MLXTestCase):
def test_simple_compile(self):
def fun(x, y):
return x + y
compiled_fn = mx.compile(fun)
compiled_fn = mx.compile(fun)
x = mx.array(1.0)
y = mx.array(1.0)
# out = compiled_fn(x, y)
if __name__ == "__main__":
unittest.main()