Compile front-end (#476)

* fix tests for linux

* make a move on compile

* basic compile scaffold works

* compile binding

* clean

* fix

* fix grad, more tests

* basic python tests

* fix segfault on python exit

* compile works with python closures

* fix test

* fix python globals bug, and erase

* simplify

* more cpp tests

* bug fix with move function and compile at exit

* simplify inputs also

* enable and disable compiler

* remove simplify

* simplify tests use compile now

* fix multi-output with compile

* clear output tree from cache when function goes out of scope

* ../python/src/transforms.cpp

* remove closure capture

* comments
This commit is contained in:
Awni Hannun
2024-01-26 13:45:30 -08:00
committed by GitHub
parent 874b739f3c
commit 8fa6b322b9
13 changed files with 1029 additions and 297 deletions

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -163,6 +162,19 @@ py::object tree_unflatten(
});
}
// Rebuild a Python tree from `tree`, substituting every `None` leaf with the
// next array taken from `values`. Arrays are consumed in the order in which
// `tree_map` visits the leaves, starting at position `index`. Leaves that are
// not `None` are returned unchanged.
//
// Throws std::out_of_range if the tree has more `None` leaves than there are
// remaining entries in `values` (or if `index` is negative), rather than
// reading past the end of the vector.
py::object tree_unflatten_none(
    py::object tree,
    const std::vector<array>& values,
    int index = 0) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<py::none>(obj)) {
      // Bounds-checked access: a placeholder tree that does not match
      // `values` surfaces as an exception instead of undefined behaviour.
      return py::cast(values.at(index++));
    } else {
      return py::cast<py::object>(obj);
    }
  });
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
@@ -437,6 +449,58 @@ auto py_vmap(
};
}
std::unordered_map<size_t, py::object>& tree_cache() {
// This map is used to Cache the tree structure of the outputs
static std::unordered_map<size_t, py::object> tree_cache_;
return tree_cache_;
}
// Callable wrapper that routes calls to a Python function through
// `detail::compile`.
//
// The address of the Python function object is used as the key (`fun_id`)
// for both `tree_cache()` and `detail::compile` / `detail::compile_erase`.
// The wrapper is move-only (construction only; both assignments are deleted)
// and its destructor erases whatever is stored under `fun_id`.
struct PyCompiledFun {
  py::function fun;
  size_t fun_id;

  PyCompiledFun(const py::function& fun)
      : fun(fun), fun_id(reinterpret_cast<size_t>(fun.ptr())) {}

  PyCompiledFun(const PyCompiledFun&) = delete;
  PyCompiledFun& operator=(const PyCompiledFun&) = delete;
  PyCompiledFun& operator=(PyCompiledFun&& other) = delete;

  // `fun` is declared before `fun_id`, so it has already been moved into by
  // the time `fun_id` is computed from `fun.ptr()`. The source's id is reset
  // to 0 so that its destructor does not erase the entries that now belong
  // to this object.
  PyCompiledFun(PyCompiledFun&& other) noexcept
      : fun(std::move(other.fun)), fun_id(reinterpret_cast<size_t>(fun.ptr())) {
    other.fun_id = 0;
  }

  // Flatten `args`, run them through the compiled function and return the
  // results arranged in the same container structure the Python function
  // produced.
  py::object operator()(const py::args& args) {
    // Invoked by `detail::compile` whenever it needs to run the Python
    // function. `args` is captured by reference, so this callable must not
    // be used after the enclosing call returns.
    auto compile_fun = [this, &args](const std::vector<array>& a) {
      // Call the python function
      py::object py_outputs = this->fun(*tree_unflatten(args, a));

      // Flatten the outputs
      auto outputs = tree_flatten(py_outputs, true);

      // Keep only the structure: every leaf is replaced by None so that
      // `tree_unflatten_none` can later fill the arrays back in.
      py_outputs =
          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });

      // NOTE(review): `insert` does not overwrite an existing entry, so the
      // structure recorded first for this id is the one that is kept.
      tree_cache().insert({this->fun_id, py_outputs});
      return outputs;
    };

    // Inputs must be array or tree of arrays
    auto inputs = tree_flatten(args, true);

    // Compile and call
    auto outputs = detail::compile(compile_fun, fun_id)(inputs);

    // Put the outputs back in the container
    py::object py_outputs = tree_cache().at(fun_id);
    return tree_unflatten_none(py_outputs, outputs);
  }

  // Drop everything stored under this function's id. For a moved-from
  // object `fun_id` is 0, so only that key is touched.
  ~PyCompiledFun() {
    tree_cache().erase(fun_id);
    detail::compile_erase(fun_id);
  }
};
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
@@ -679,45 +743,6 @@ void init_transforms(py::module_& m) {
Returns:
function: The vectorized function.
)pbdoc");
m.def(
"simplify",
[](const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
simplify(arrays);
},
R"pbdoc(
simplify(*args) -> None
Simplify the graph that computes the arrays.
Run a few fast graph simplification operations to reuse computation and
reduce memory consumption. This function is meant to be run every time
so its overhead should be small, approximately 1ms for a graph with a
few thousand nodes.
.. code-block:: python
import mlx.core as mx
def foo(x):
y = x @ x
z = x @ x
return y + z
x = mx.ones((10, 10))
y = foo(x)
z = foo(x)
# Computes the matmul twice
mx.eval(y)
# Computes the matmul once
mx.simplify(z)
mx.eval(z)
Args:
args: Any number of arrays and/or trees of arrays to be simplified.
)pbdoc");
m.def(
"export_to_dot",
[](py::object file, const py::args& args) {
@@ -736,4 +761,46 @@ void init_transforms(py::module_& m) {
}
},
"file"_a);
// Bind `compile`: wrap the Python callable in a PyCompiledFun and return it
// to Python as a callable object.
m.def(
"compile",
[](const py::function& fun) {
return py::cpp_function(PyCompiledFun{fun});
},
"fun"_a,
R"pbdoc(
compile(fun: function) -> function
Returns a compiled function which produces the same output as ``fun``.
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
Returns:
function: A compiled function which has the same input arguments
as ``fun`` and returns the same output(s).
)pbdoc");
// Bind `disable_compile` directly to the function of the same name. The
// first docstring line is the signature shown to users.
m.def(
"disable_compile",
&disable_compile,
R"pbdoc(
disable_compile() -> None
Globally disable compilation. Setting the environment variable
``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
)pbdoc");
// Bind `enable_compile` directly to the function of the same name. The
// first docstring line is the signature shown to users, so it has to match
// the bound name.
m.def(
"enable_compile",
&enable_compile,
R"pbdoc(
enable_compile() -> None
Globally enable compilation. This will override the environment
variable ``MLX_DISABLE_COMPILE`` if set.
)pbdoc");
// Register static Python object cleanup before the interpreter exits.
// tree_cache() holds py::object values in a function-local static; clearing
// it from an `atexit` callback releases them at that point instead of at
// static destruction time (presumably to avoid touching Python objects
// after the interpreter has shut down — confirm).
auto atexit = py::module_::import("atexit");
atexit.attr("register")(py::cpp_function([]() { tree_cache().clear(); }));
}

View File

@@ -0,0 +1,195 @@
# Copyright © 2023-2024 Apple Inc.
import io
import unittest
import mlx.core as mx
import mlx_tests
class TestCompile(mlx_tests.MLXTestCase):
    """Tests for ``mx.compile`` and the global compile enable/disable switches."""

    def test_simple_compile(self):
        """Compiled addition matches the expected values across shapes and dtypes."""

        def fun(x, y):
            return x + y

        # NOTE(review): the function is compiled twice in a row; the second
        # wrapper replaces the first. Presumably intentional — confirm.
        compiled_fn = mx.compile(fun)
        compiled_fn = mx.compile(fun)
        x = mx.array(1.0)
        y = mx.array(1.0)
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Try again
        out = compiled_fn(x, y)
        self.assertEqual(out.item(), 2.0)

        # Change sizes
        x = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))

        y = mx.array([1.0, 2.0])
        out = compiled_fn(x, y)
        self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))

        # Change types
        x = mx.array([1, 2], mx.int32)
        y = mx.array([1, 2], mx.int32)
        out = compiled_fn(x, y)
        self.assertEqual(out.dtype, mx.int32)
        self.assertTrue(mx.array_equal(out, mx.array([2, 4])))

    def test_compile_grad(self):
        """Compiling ``grad`` / ``value_and_grad`` functions gives the eager results."""

        def loss_fn(x):
            return mx.exp(x).sum()

        grad_fn = mx.grad(loss_fn)

        x = mx.array([0.5, -0.5, 1.2])
        dfdx = grad_fn(x)
        compile_grad_fn = mx.compile(grad_fn)
        # Call the compiled function here; calling `grad_fn` again would only
        # compare the eager result with itself.
        c_dfdx = compile_grad_fn(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again without calling compile
        c_dfdx = compile_grad_fn(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Run it again with calling compile
        c_dfdx = mx.compile(grad_fn)(x)
        self.assertTrue(mx.allclose(c_dfdx, dfdx))

        # Value and grad
        def loss_fn(x):
            return mx.exp(x).sum(), mx.sin(x)

        val_and_grad_fn = mx.value_and_grad(loss_fn)
        (loss, val), dfdx = val_and_grad_fn(x)
        (c_loss, c_val), c_dfdx = mx.compile(val_and_grad_fn)(x)

        self.assertTrue(mx.allclose(c_dfdx, dfdx))
        self.assertTrue(mx.allclose(c_loss, loss))
        self.assertTrue(mx.allclose(c_val, val))

    def test_compile_inputs_with_primitives(self):
        """Inputs that are themselves unevaluated graphs can be passed to a compiled function."""
        x = mx.array([1, 2, 3])
        y = mx.array([1, 2, 3])
        for _ in range(5):
            x = x + y
            y = y + 1

        def fun(x, y):
            return x * y

        out = fun(x, y)

        # Rebuild the same inputs from scratch for the compiled call
        x = mx.array([1, 2, 3])
        y = mx.array([1, 2, 3])
        for _ in range(5):
            x = x + y
            y = y + 1

        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

        # Try again
        c_out = mx.compile(fun)(x, y)
        self.assertTrue(mx.array_equal(out, c_out))

    def test_compile_with_closure(self):
        """Compiled closures keep using the enclosed values seen at first call."""
        x = mx.array(1)

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Try again
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 2)

        # Change the shape of the enclosed variable
        x = mx.array([1, 2])
        out = compiled(mx.array(1))

        # We still get the original input (closures are not updated)
        self.assertEqual(out.item(), 2)

        # Try with a tree of enclosed variables
        x = {"a": mx.array(1), "b": mx.array(2)}

        def closure(y):
            return x["a"] + y + x["b"]

        compiled = mx.compile(closure)
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Change the shape of one input
        x["a"] = mx.array([4, 5])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        x["b"] = mx.array([-6, -8])
        out = compiled(mx.array(1))
        self.assertEqual(out.item(), 4)

        # Enclosed variable is not evaluated yet
        x = mx.array(1)
        x = x + x

        def closure(y):
            return x + y

        compiled = mx.compile(closure)
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

        # And again
        out = compiled(mx.array(2))
        self.assertEqual(out.item(), 4)

    def test_function_creates_array(self):
        """A compiled function may create new arrays inside its body."""

        def fun(x):
            return x + mx.array(1)

        cfun = mx.compile(fun)
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

        # And again
        out = cfun(mx.array(3))
        self.assertEqual(out.item(), 4)

    def test_enable_disable(self):
        """Disabling compilation yields a larger graph; re-enabling restores it."""

        def fun(x):
            y = x + 1
            z = x + 1
            return y + z

        def count_prims(outputs):
            # Count whitespace-separated tokens of the exported dot graph
            # that contain "label".
            buf = io.StringIO()
            mx.export_to_dot(buf, outputs)
            buf.seek(0)
            return len([token for token in buf.read().split() if "label" in token])

        x = mx.array(1.0)
        cfun = mx.compile(fun)
        n_compiled = count_prims(cfun(x))

        # Check disabled. The cleanup turns compilation back on even if an
        # assertion below fails, so the global switch does not leak into
        # other tests.
        self.addCleanup(mx.enable_compile)
        mx.disable_compile()
        n_uncompiled = count_prims(cfun(x))
        self.assertLess(n_compiled, n_uncompiled)

        # Check re-enabled
        mx.enable_compile()
        n_enable_compiled = count_prims(cfun(x))
        self.assertEqual(n_compiled, n_enable_compiled)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()