compile works with python closures

This commit is contained in:
Awni Hannun 2024-01-15 11:26:28 -08:00
parent 966b7faef4
commit 4f50935c2c
2 changed files with 103 additions and 5 deletions

View File

@ -457,14 +457,10 @@ std::unordered_map<size_t, py::object>& tree_cache() {
auto py_compile(const py::function& fun) {
return [fun](const py::args& args) {
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
// TODO, awni, I think this cast is ok??
size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
auto compile_fun = [fun_id, &fun, &args, &inputs](
const std::vector<array>& a) {
auto compile_fun = [fun_id, &fun, &args](const std::vector<array>& a) {
// Call the python function
py::object py_outputs = fun(*tree_unflatten(args, a));
@ -477,6 +473,29 @@ auto py_compile(const py::function& fun) {
return outputs;
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
// Get globally enclosed arrays so we don't compile through them
auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false);
std::move(
std::begin(global_inputs),
std::end(global_inputs),
std::back_inserter(inputs));
// Get locally enclosed arrays so we don't compile through them
auto closures = py::getattr(fun, "__closure__");
if (py::isinstance<py::tuple>(closures)) {
for (auto& closure : closures) {
auto enclosed_inputs =
tree_flatten(py::getattr(closure, "cell_contents"), false);
std::move(
std::begin(enclosed_inputs),
std::end(enclosed_inputs),
std::back_inserter(inputs));
}
}
// Compile and call
auto outputs = detail::compile(compile_fun, fun_id)(inputs);

View File

@ -22,6 +22,22 @@ class TestCompile(mlx_tests.MLXTestCase):
out = compiled_fn(x, y)
self.assertEqual(out.item(), 2.0)
# Change sizes
x = mx.array([1.0, 2.0])
out = compiled_fn(x, y)
self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0])))
y = mx.array([1.0, 2.0])
out = compiled_fn(x, y)
self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0])))
# Change types
x = mx.array([1, 2], mx.int32)
y = mx.array([1, 2], mx.int32)
out = compiled_fn(x, y)
self.assertEqual(out.dtype, mx.int32)
self.assertTrue(mx.array_equal(out, mx.array([2, 4])))
def test_compile_grad(self):
def loss_fn(x):
return mx.exp(x).sum()
@ -55,6 +71,69 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(c_loss, loss))
self.assertTrue(mx.allclose(c_val, val))
def test_compile_inputs_with_primitives(self):
x = mx.array([1, 2, 3])
y = mx.array([1, 2, 3])
for _ in range(5):
x = x + y
y = y + 1
def fun(x, y):
return x * y
out = fun(x, y)
x = mx.array([1, 2, 3])
y = mx.array([1, 2, 3])
for _ in range(5):
x = x + y
y = y + 1
c_out = mx.compile(fun)(x, y)
self.assertTrue(mx.array_equal(out, c_out))
# Try again
c_out = mx.compile(fun)(x, y)
self.assertTrue(mx.array_equal(out, c_out))
def test_compile_with_closure(self):
x = mx.array(1)
def closure(y):
return x + y
compiled = mx.compile(closure)
out = compiled(mx.array(1))
self.assertTrue(out.item(), 2)
# Try again
out = compiled(mx.array(1))
self.assertTrue(out.item(), 2)
# Change the shape of the enclosed variable
x = mx.array([1, 2])
out = compiled(mx.array(1))
self.assertTrue(mx.array_equal(out, mx.array([2, 3])))
# Try with a tree of enclosed variables
x = {"a": mx.array(1), "b": mx.array(2)}
def closure(y):
return x["a"] + y + x["b"]
compiled = mx.compile(closure)
out = compiled(mx.array(1))
self.assertEqual(out.item(), 4)
# Change the shape of one input
x["a"] = mx.array([4, 5])
out = compiled(mx.array(1))
self.assertTrue(mx.array_equal(out, mx.array([7, 8])))
x["b"] = mx.array([-6, -8])
out = compiled(mx.array(1))
self.assertTrue(mx.array_equal(out, mx.array([-1, -2])))
if __name__ == "__main__":
unittest.main()