diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index ab7bbc385..91d9f8803 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -1,5 +1,4 @@
 // Copyright © 2023 Apple Inc.
-
 #include
 #include
 #include
@@ -163,6 +162,19 @@ py::object tree_unflatten(
   });
 }
 
+py::object tree_unflatten_none(
+    py::object tree,
+    const std::vector<array>& values,
+    int index = 0) {
+  return tree_map(tree, [&](py::handle obj) {
+    if (py::isinstance<py::none>(obj)) {
+      return py::cast(values[index++]);
+    } else {
+      return py::cast<py::object>(obj);
+    }
+  });
+}
+
 auto validate_argnums_argnames(
     const std::optional<IntOrVec>& argnums,
     const StrOrVec& argnames) {
@@ -438,30 +450,36 @@ auto py_vmap(
 }
 
 auto py_compile(const py::function& fun) {
+  // This map is used to cache the tree structure of the outputs
+  static std::unordered_map<size_t, py::object> tree_cache;
+
   return [fun](const py::args& args) {
     // Inputs must be array or tree of arrays
     auto inputs = tree_flatten(args, true);
 
-    // py_value_out will hold the output of the python function in order to be
-    // able to reconstruct the python tree of extra return values
-    py::object py_outputs;
-
-    auto compile_fun =
-        [&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
-          // Call the python function
-          py_outputs = fun(*tree_unflatten(args, a));
-
-          // Flatten the outputs
-          return tree_flatten(py_outputs, true);
-        };
-
-    // Compile and call
     // TODO, awni, I think this cast is ok??
     size_t fun_id = reinterpret_cast<size_t>(fun.ptr());
+
+    auto compile_fun = [fun_id, &fun, &args, &inputs](
+                           const std::vector<array>& a) {
+      // Call the python function
+      py::object py_outputs = fun(*tree_unflatten(args, a));
+
+      // Flatten the outputs
+      auto outputs = tree_flatten(py_outputs, true);
+
+      py_outputs =
+          tree_map(py_outputs, [](const py::handle& x) { return py::none(); });
+      tree_cache.insert({fun_id, py_outputs});
+      return outputs;
+    };
+
+    // Compile and call
     auto outputs = detail::compile(compile_fun, fun_id)(inputs);
 
     // Put the outputs back in the container
-    return tree_unflatten(py_outputs, outputs);
+    py::object py_outputs = tree_cache.at(fun_id);
+    return tree_unflatten_none(py_outputs, outputs);
   };
 }
 
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index 39e075d72..f9b36414b 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -15,7 +15,12 @@ class TestCompile(mlx_tests.MLXTestCase):
         compiled_fn = mx.compile(fun)
         x = mx.array(1.0)
         y = mx.array(1.0)
-        # out = compiled_fn(x, y)
+        out = compiled_fn(x, y)
+        self.assertEqual(out.item(), 2.0)
+
+        # Try again
+        out = compiled_fn(x, y)
+        self.assertEqual(out.item(), 2.0)
 
 
 if __name__ == "__main__":
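The `py_compile` change above caches the structure of each function's Python outputs, with every array leaf replaced by `None`, so that later calls, which reuse the compiled graph without re-running the Python function, can still rebuild structured return values. Below is a minimal pure-Python sketch of that caching scheme; `tree_map`, `tree_cache`, `remember_structure`, and `reconstruct` are hypothetical stand-ins for illustration and are not the binding's actual helpers.

```python
# Hypothetical sketch of the output-tree caching used by py_compile above.
# None of these names exist in MLX; they only illustrate the mechanism.
tree_cache = {}  # fun_id -> output tree with every leaf replaced by None


def tree_map(tree, fn):
    # Apply fn to every leaf of a nested structure of lists/tuples/dicts.
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(t, fn) for t in tree)
    if isinstance(tree, dict):
        return {k: tree_map(v, fn) for k, v in tree.items()}
    return fn(tree)


def remember_structure(fun_id, outputs):
    # Keep only the shape of the output container, not the values themselves.
    tree_cache[fun_id] = tree_map(outputs, lambda _: None)


def reconstruct(fun_id, flat_values):
    # Refill the cached None placeholders with freshly computed values, in order.
    it = iter(flat_values)
    return tree_map(tree_cache[fun_id], lambda _: next(it))


# Example: the first call records the structure, later calls only refill it.
remember_structure(42, (1.0, {"loss": 2.0}))
print(reconstruct(42, [3.0, 4.0]))  # (3.0, {'loss': 4.0})
```

Here `reconstruct` plays the role of `tree_unflatten_none`: because the cached tree marks every leaf with `None`, refilling is a simple in-order walk that does not need the original output arrays to stay alive.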
diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp
index dc8c3e980..94d108a4d 100644
--- a/tests/compile_tests.cpp
+++ b/tests/compile_tests.cpp
@@ -1,6 +1,8 @@
 // Copyright © 2023 Apple Inc.
 
+#include  // TODO
 #include "doctest/doctest.h"
+#include "mlx/utils.h"  // TODO
 
 #include "mlx/mlx.h"
@@ -33,17 +35,50 @@ TEST_CASE("test simple compile") {
   CHECK(array_equal(out, array({3.0f, 4.0f})).item<bool>());
 }
 
-std::vector<array> fun1(const std::vector<array>& inputs) {
+std::vector<array> grad_fun(const std::vector<array>& inputs) {
   auto loss = [](std::vector<array> ins) { return exp(ins[0] + ins[1]); };
-  return grad(loss)(inputs);
+  return grad(loss, {0, 1})(inputs);
 }
 
 TEST_CASE("test compile with grad") {
   auto x = array(1.0f);
   auto y = array(1.0f);
-  auto grads_expected = fun1({x, y});
-  auto grads_compile = compile(fun1)({x, y});
+  auto grads_expected = grad_fun({x, y});
+  auto grads_compile = compile(grad_fun)({x, y});
   CHECK_EQ(grads_compile[0].item<float>(), grads_expected[0].item<float>());
+  CHECK_EQ(grads_compile[1].item<float>(), grads_expected[1].item<float>());
 }
 
+TEST_CASE("test compile inputs with primitive") {
+  auto [k1, k2] = random::split(random::key(0));
+  auto x = random::uniform({5, 5}, k1);
+  auto y = random::uniform({5, 5}, k2);
+  auto expected = simple_fun({x, y})[0];
+
+  x = random::uniform({5, 5}, k1);
+  y = random::uniform({5, 5}, k2);
+  auto out = compile(simple_fun)({x, y})[0];
+  CHECK(array_equal(expected, out).item<bool>());
+
+  // Same thing twice
+  out = compile(simple_fun)({x, y})[0];
+  CHECK(array_equal(expected, out).item<bool>());
+}
+
+/*std::vector<array> bigger_fun(const std::vector<array>& inputs) {
+  auto x = inputs[0];
+  for (int i = 1; i < inputs.size(); ++i) {
+    auto w = inputs[i];
+    x = maximum(matmul(x, w), 0);
+  }
+  return {take(x, array(3)) - logsumexp(x)};
+}
+
+TEST_CASE("test bigger graph") {
+  std::vector<array> inputs;
+  inputs.push_back(
+  for (int
+  for
+}*/
+
 TEST_CASE("test nested compile") {}
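For reference, a hypothetical Python-side usage that mirrors the new C++ `grad_fun` test above. It assumes that `mx.grad` composes with `mx.compile` and that the tuple of gradients round-trips through the cached output tree; neither is covered by the committed Python test, which only checks a single scalar output.

```python
import mlx.core as mx


def loss(x, y):
    return mx.exp(x + y)


# Compile the gradient function; it returns a tuple (dx, dy), which exercises
# the cached output-tree reconstruction in py_compile.
grad_fn = mx.compile(mx.grad(loss, argnums=[0, 1]))

x = mx.array(1.0)
y = mx.array(1.0)
dx, dy = grad_fn(x, y)

# Calling again should reuse both the compiled graph and the cached structure.
dx2, dy2 = grad_fn(x, y)
```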