mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Compile front-end (#476)
* fix tests for linux * make a move on compile * basic compile scaffold works * compile binding * clean * fix * fix grad, more tests * basic python tests * fix segfault on python exit * compile works with python closures * fix test * fix python globals bug, and erase * simplify * more cpp tests * bug fix with move function and compile at exit * simplify inputs also * enable and disable compiler * remove simplify * simplify tests use compile now * fix multi-output with compile * clear output tree from cache when function goes out of scope * ../python/src/transforms.cpp * remove closure capture * comments
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
@@ -163,6 +162,19 @@ py::object tree_unflatten(
|
||||
});
|
||||
}
|
||||
|
||||
// Rebuild a Python tree by substituting arrays for its `None` leaves.
//
// Walks `tree` with `tree_map`. Each leaf that is `None` is replaced by the
// next entry of `values` (consumption starts at position `index`); every other
// leaf is passed through unchanged.
//
// Args:
//   tree:   Python container whose `None` leaves mark where arrays go.
//   values: Arrays to place into the tree, consumed one per `None` leaf.
//   index:  Position in `values` of the first array to consume.
//
// Returns:
//   The tree produced by `tree_map` with the `None` leaves filled in.
//
// Throws:
//   std::out_of_range if the tree has more `None` leaves than there are
//   remaining entries in `values` (bounds-checked via `at()` rather than
//   reading past the end of the vector).
py::object tree_unflatten_none(
    py::object tree,
    const std::vector<array>& values,
    int index = 0) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<py::none>(obj)) {
      // `at()` guards against a tree/values size mismatch.
      return py::cast(values.at(index++));
    } else {
      return py::cast<py::object>(obj);
    }
  });
}
|
||||
|
||||
auto validate_argnums_argnames(
|
||||
const std::optional<IntOrVec>& argnums,
|
||||
const StrOrVec& argnames) {
|
||||
@@ -437,6 +449,58 @@ auto py_vmap(
|
||||
};
|
||||
}
|
||||
|
||||
// Accessor for the shared map that caches the tree structure of function
// outputs. Entries are keyed by a `size_t` id and hold a Python object.
std::unordered_map<size_t, py::object>& tree_cache() {
  // Function-local static so every caller gets a reference to the same map.
  static std::unordered_map<size_t, py::object> output_structures;
  return output_structures;
}
|
||||
|
||||
struct PyCompiledFun {
|
||||
py::function fun;
|
||||
size_t fun_id;
|
||||
|
||||
PyCompiledFun(const py::function& fun)
|
||||
: fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}
|
||||
|
||||
PyCompiledFun(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(const PyCompiledFun&) = delete;
|
||||
PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
|
||||
PyCompiledFun(PyCompiledFun&& other)
|
||||
: fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
|
||||
other.fun_id = 0;
|
||||
};
|
||||
|
||||
py::object operator()(const py::args& args) {
|
||||
auto compile_fun = [this, &args](const std::vector<array>& a) {
|
||||
// Call the python function
|
||||
py::object py_outputs = this->fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
auto outputs = tree_flatten(py_outputs, true);
|
||||
|
||||
py_outputs =
|
||||
tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
|
||||
tree_cache().insert({this->fun_id, py_outputs});
|
||||
return outputs;
|
||||
};
|
||||
|
||||
// Inputs must be array or tree of arrays
|
||||
auto inputs = tree_flatten(args, true);
|
||||
|
||||
// Compile and call
|
||||
auto outputs = detail::compile(compile_fun, fun_id)(inputs);
|
||||
|
||||
// Put the outputs back in the container
|
||||
py::object py_outputs = tree_cache().at(fun_id);
|
||||
return tree_unflatten_none(py_outputs, outputs);
|
||||
};
|
||||
|
||||
~PyCompiledFun() {
|
||||
tree_cache().erase(fun_id);
|
||||
detail::compile_erase(fun_id);
|
||||
}
|
||||
};
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
@@ -679,45 +743,6 @@ void init_transforms(py::module_& m) {
|
||||
Returns:
|
||||
function: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"simplify",
|
||||
[](const py::args& args) {
|
||||
std::vector<array> arrays = tree_flatten(args);
|
||||
simplify(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
simplify(*args) -> None
|
||||
|
||||
Simplify the graph that computes the arrays.
|
||||
|
||||
Run a few fast graph simplification operations to reuse computation and
|
||||
reduce memory consumption. This function is meant to be run every time
|
||||
so its overhead should be small, approximately 1ms for a graph with a
|
||||
few thousand nodes.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
def foo(x):
|
||||
y = x @ x
|
||||
z = x @ x
|
||||
return y + z
|
||||
|
||||
x = mx.ones((10, 10))
|
||||
y = foo(x)
|
||||
z = foo(x)
|
||||
|
||||
# Computes the matmul twice
|
||||
mx.eval(y)
|
||||
|
||||
# Computes the matmul once
|
||||
mx.simplify(z)
|
||||
mx.eval(z)
|
||||
|
||||
Args:
|
||||
args: Any number of arrays and/or trees of arrays to be simplified.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
[](py::object file, const py::args& args) {
|
||||
@@ -736,4 +761,46 @@ void init_transforms(py::module_& m) {
|
||||
}
|
||||
},
|
||||
"file"_a);
|
||||
m.def(
|
||||
"compile",
|
||||
[](const py::function& fun) {
|
||||
return py::cpp_function(PyCompiledFun{fun});
|
||||
},
|
||||
"fun"_a,
|
||||
R"pbdoc(
|
||||
compile(fun: function) -> function
|
||||
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (function): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
|
||||
Returns:
|
||||
function: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"disable_compile",
|
||||
&disable_compile,
|
||||
R"pbdoc(
|
||||
disable_compile() -> None
|
||||
|
||||
Globally disable compilation. Setting the environment variable
|
||||
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"enable_compile",
|
||||
&enable_compile,
|
||||
R"pbdoc(
|
||||
enable_compile() -> None
|
||||
|
||||
Globally enable compilation. This will override the environment
|
||||
variable ``MLX_DISABLE_COMPILE`` if set.
|
||||
)pbdoc");
|
||||
|
||||
// Register static Python object cleanup before the interpreter exits
|
||||
auto atexit = py::module_::import("atexit");
|
||||
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
|
||||
}
|
||||
|
195
python/tests/test_compile.py
Normal file
195
python/tests/test_compile.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
|
||||
|
||||
class TestCompile(mlx_tests.MLXTestCase):
    """Tests for ``mx.compile``, ``mx.disable_compile`` and ``mx.enable_compile``."""

    def test_simple_compile(self):
        """A compiled add returns the expected sums on repeated calls and
        after the input shapes or dtypes change."""

        def fun(x, y):
            return x + y

        # The same function is compiled twice; only the second wrapper is used.
        compiled_fn = mx.compile(fun)
        compiled_fn = mx.compile(fun)
        x = mx.array(1.0)
        y = mx.array(1.0)
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Try again
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Change sizes
        x = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))

        y = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))

        # Change types
        x = mx.array([1, 2], mx.int32)
        y = mx.array([1, 2], mx.int32)
        out = compiled_fn(x, y)
        self.assertEqual(out.dtype, mx.int32)
        self.assertTrue(mx.array_equal(out, mx.array([2, 4])))

    def test_compile_grad(self):
        """Compiled ``grad`` and ``value_and_grad`` functions agree with the
        uncompiled ones."""

        def loss_fn(x):
            return mx.exp(x).sum()

        grad_fn = mx.grad(loss_fn)

        x = mx.array([0.5, -0.5, 1.2])
        dfdx = grad_fn(x)
        compile_grad_fn = mx.compile(grad_fn)
        # Call the compiled function (not ``grad_fn``) so the comparison
        # actually exercises the compiled path.
        c_dfdx = compile_grad_fn(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again without calling compile
        c_dfdx = compile_grad_fn(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again with calling compile
        c_dfdx = mx.compile(grad_fn)(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Value and grad
        def loss_fn(x):
            return mx.exp(x).sum(), mx.sin(x)

        val_and_grad_fn = mx.value_and_grad(loss_fn)
        (loss, val), dfdx = val_and_grad_fn(x)
        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))
        self.assertTrue(mx.allclose(c_loss, loss))
        self.assertTrue(mx.allclose(c_val, val))

    def test_compile_inputs_with_primitives(self):
        """Inputs that are themselves results of earlier array operations give
        the same output through the compiled and uncompiled function."""
        x = mx.array([1, 2, 3])
        y = mx.array([1, 2, 3])
        for _ in range(5):
            x = x + y
            y = y + 1

        def fun(x, y):
            return x * y

        out = fun(x, y)

        # Rebuild the same inputs from scratch for the compiled call.
        x = mx.array([1, 2, 3])
        y = mx.array([1, 2, 3])
        for _ in range(5):
            x = x + y
            y = y + 1

        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

        # Try again
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

    def test_compile_with_closure(self):
        """Compiled closures keep returning the result computed from the
        enclosed values seen on the first call."""
        x = mx.array(1)

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Try again
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Change the shape of the enclosed variable
        x = mx.array([1, 2])
        out = compiled(mx.array(1))

        # We still get the original input (closures are not updated)
        self.assertEqual(out.item(), 2)

        # Try with a tree of enclosed variables
        x = {"a": mx.array(1), "b": mx.array(2)}

        def closure(y):
            return x["a"] + y + x["b"]

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Change the shape of one input
        x["a"] = mx.array([4, 5])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        x["b"] = mx.array([-6, -8])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Enclosed variable is not evaluated yet
        x = mx.array(1)
        x = x + x

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

        # And again
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

    def test_function_creates_array(self):
        """A compiled function may create a new array inside its body."""

        def fun(x):
            return x + mx.array(1)

        cfun = mx.compile(fun)
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

        # And again
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

    def test_enable_disable(self):
        """Disabling compilation yields a larger exported graph; re-enabling it
        restores the compiled size."""

        def fun(x):
            y = x + 1
            z = x + 1
            return y + z

        def count_prims(outputs):
            # Export the graph to DOT text and count the whitespace-separated
            # tokens that contain "label".
            buf = io.StringIO()
            mx.export_to_dot(buf, outputs)
            buf.seek(0)
            return len([tok for tok in buf.read().split() if "label" in tok])

        x = mx.array(1.0)
        cfun = mx.compile(fun)
        n_compiled = count_prims(cfun(x))

        # Check disabled
        mx.disable_compile()
        n_uncompiled = count_prims(cfun(x))
        self.assertTrue(n_compiled < n_uncompiled)

        # Check re-enabled
        mx.enable_compile()
        n_enable_compiled = count_prims(cfun(x))
        self.assertEqual(n_compiled, n_enable_compiled)
|
||||
|
||||
|
||||
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user