mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-04 02:28:13 +08:00
compile binding
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <iostream> // TODO
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
@@ -437,6 +438,34 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
auto py_compile(const py::function& fun) {
|
||||
return [fun](const py::args& args) {
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// py_value_out will hold the output of the python function in order to be
|
||||
// able to reconstruct the python tree of extra return values
|
||||
py::object py_outputs;
|
||||
|
||||
auto compile_fun =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
return tree_flatten(py_outputs, true);
|
||||
};
|
||||
|
||||
// Compile and call
|
||||
// TODO, awni, I think this cast is ok??
|
||||
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
||||
// Put the outputs back in the container
|
||||
return tree_unflatten(py_outputs, outputs);
|
||||
};
|
||||
}
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
@@ -736,4 +765,22 @@ void init_transforms(py::module_& m) {
|
||||
}
|
||||
},
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) { return py::cpp_function(py_compile(fun)); },
|
||||
"fun"_a,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
22
python/tests/test_compile.py
Normal file
22
python/tests/test_compile.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestCompile(mlx_tests.MLXTestCase):
|
||||
def test_simple_compile(self):
|
||||
def fun(x, y):
|
||||
return x + y
|
||||
|
||||
compiled_fn = mx.compile(fun)
|
||||
compiled_fn = mx.compile(fun)
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(1.0)
|
||||
# out = compiled_fn(x, y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user