diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 94d108a4d..b012fff36 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -1,8 +1,6 @@ // Copyright © 2023 Apple Inc. -#include // TODO #include "doctest/doctest.h" -#include "mlx/utils.h" // TODO #include "mlx/mlx.h" @@ -65,20 +63,36 @@ TEST_CASE("test compile inputs with primitive") { CHECK(array_equal(expected, out).item()); } -/*std::vector bigger_fun(const std::vector& inputs) { - auto x = inputs[1]; - for (int i = 1; i < inputs.size(); ++i) { - w = inputs[i] - x = maximum(matmul(x, w), 0); - } - return take(x, array(3)) - logsumexp(x); +std::vector fun_creats_array(const std::vector& inputs) { + return {inputs[0] + array(1.0f)}; } -TEST_CASE("test bigger graph") { - std::vector inputs; - inputs.push_back( - for (int - for -}*/ +TEST_CASE("test compile with created array") { + auto cfun = compile(fun_creats_array); + auto out = cfun({array(2.0f)}); + CHECK_EQ(out[0].item(), 3.0f); -TEST_CASE("test nested compile") {} + // Try again + out = cfun({array(2.0f)}); + CHECK_EQ(out[0].item(), 3.0f); +} + +std::vector inner_fun(const std::vector& inputs) { + return {array(2) * inputs[0]}; +} + +std::vector outer_fun(const std::vector& inputs) { + auto x = inputs[0] + inputs[1]; + auto y = compile(inner_fun)({x})[0]; + return {x + y}; +} + +TEST_CASE("test nested compile") { + auto cfun = compile(outer_fun); + auto out = cfun({array(1), array(2)})[0]; + CHECK_EQ(out.item(), 9); + + // Try again + out = cfun({array(1), array(2)})[0]; + CHECK_EQ(out.item(), 9); +}