From 4f50935c2ccd508d0151ed431e1a34e5d4df8a3f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 15 Jan 2024 11:26:28 -0800 Subject: [PATCH] compile works with python closures --- python/src/transforms.cpp | 29 ++++++++++--- python/tests/test_compile.py | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index bef7f8993..ac1957c78 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -457,14 +457,10 @@ std::unordered_map& tree_cache() { auto py_compile(const py::function& fun) { return [fun](const py::args& args) { - // Inputs must be array or tree of arrays - auto inputs = tree_flatten(args, true); - // TODO, awni, I think this cast is ok?? size_t fun_id = reinterpret_cast(fun.ptr()); - auto compile_fun = [fun_id, &fun, &args, &inputs]( - const std::vector& a) { + auto compile_fun = [fun_id, &fun, &args](const std::vector& a) { // Call the python function py::object py_outputs = fun(*tree_unflatten(args, a)); @@ -477,6 +473,29 @@ auto py_compile(const py::function& fun) { return outputs; }; + // Inputs must be array or tree of arrays + auto inputs = tree_flatten(args, true); + + // Get globally enclosed arrays so we don't compile through them + auto global_inputs = tree_flatten(py::getattr(fun, "__globals__"), false); + std::move( + std::begin(global_inputs), + std::end(global_inputs), + std::back_inserter(inputs)); + + // Get locally enclosed arrays so we don't compile through them + auto closures = py::getattr(fun, "__closure__"); + if (py::isinstance(closures)) { + for (auto& closure : closures) { + auto enclosed_inputs = + tree_flatten(py::getattr(closure, "cell_contents"), false); + std::move( + std::begin(enclosed_inputs), + std::end(enclosed_inputs), + std::back_inserter(inputs)); + } + } + // Compile and call auto outputs = detail::compile(compile_fun, fun_id)(inputs); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index a8bbb834b..c9961ea4d 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -22,6 +22,22 @@ class TestCompile(mlx_tests.MLXTestCase): out = compiled_fn(x, y) self.assertEqual(out.item(), 2.0) + # Change sizes + x = mx.array([1.0, 2.0]) + out = compiled_fn(x, y) + self.assertTrue(mx.array_equal(out, mx.array([2.0, 3.0]))) + + y = mx.array([1.0, 2.0]) + out = compiled_fn(x, y) + self.assertTrue(mx.array_equal(out, mx.array([2.0, 4.0]))) + + # Change types + x = mx.array([1, 2], mx.int32) + y = mx.array([1, 2], mx.int32) + out = compiled_fn(x, y) + self.assertEqual(out.dtype, mx.int32) + self.assertTrue(mx.array_equal(out, mx.array([2, 4]))) + def test_compile_grad(self): def loss_fn(x): return mx.exp(x).sum() @@ -55,6 +71,69 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(c_loss, loss)) self.assertTrue(mx.allclose(c_val, val)) + def test_compile_inputs_with_primitives(self): + x = mx.array([1, 2, 3]) + y = mx.array([1, 2, 3]) + for _ in range(5): + x = x + y + y = y + 1 + + def fun(x, y): + return x * y + + out = fun(x, y) + + x = mx.array([1, 2, 3]) + y = mx.array([1, 2, 3]) + for _ in range(5): + x = x + y + y = y + 1 + + c_out = mx.compile(fun)(x, y) + self.assertTrue(mx.array_equal(out, c_out)) + + # Try again + c_out = mx.compile(fun)(x, y) + self.assertTrue(mx.array_equal(out, c_out)) + + def test_compile_with_closure(self): + x = mx.array(1) + + def closure(y): + return x + y + + compiled = mx.compile(closure) + out = compiled(mx.array(1)) + self.assertTrue(out.item(), 2) + + # Try again + out = compiled(mx.array(1)) + self.assertTrue(out.item(), 2) + + # Change the shape of the enclosed variable + x = mx.array([1, 2]) + out = compiled(mx.array(1)) + self.assertTrue(mx.array_equal(out, mx.array([2, 3]))) + + # Try with a tree of enclosed variables + x = {"a": mx.array(1), "b": mx.array(2)} + + def closure(y): + return x["a"] + y + x["b"] + + compiled = mx.compile(closure) + out = compiled(mx.array(1)) + self.assertEqual(out.item(), 4) + + # Change the shape of one input + x["a"] = mx.array([4, 5]) + out = compiled(mx.array(1)) + self.assertTrue(mx.array_equal(out, mx.array([7, 8]))) + + x["b"] = mx.array([-6, -8]) + out = compiled(mx.array(1)) + self.assertTrue(mx.array_equal(out, mx.array([-1, -2]))) + if __name__ == "__main__": unittest.main()