From 9739c72781c4d2bf6afb49e4077761171305222f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 14 Jan 2024 19:42:36 -0800 Subject: [PATCH] fix grad, more tests --- mlx/compile.cpp | 39 +++++++++++++++------------------------ tests/compile_tests.cpp | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index b6bd5cc1a..cfdd775d4 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -107,10 +107,8 @@ std::pair, ParentsMap> compile_dfs( std::unordered_set cache; std::unordered_map>> parents_map; - std::unordered_set needs_compile; for (int i = 0; i < inputs.size(); ++i) { auto in = inputs[i]; - needs_compile.insert(in.id()); cache.insert(in.id()); } @@ -132,16 +130,7 @@ std::pair, ParentsMap> compile_dfs( for (auto& s : a.siblings()) { cache.insert(s.id()); } - for (auto& input : a.inputs()) { - if (needs_compile.find(input.id()) != needs_compile.end()) { - tape.push_back(a); - needs_compile.insert(a.id()); - for (auto s : a.siblings()) { - needs_compile.insert(s.id()); - } - break; - } - } + tape.push_back(a); }; for (auto& a : outputs) { recurse(a); @@ -254,7 +243,7 @@ void compile_simplify( for (auto& o : outputs) { output_set.insert(o.id()); } - // Pass 1 to passes: fuse only keeping non-orphaned arrays in the tape + // Pass 1..passes: fuse only keeping non-orphaned arrays in the tape for (int pass = 0; pass < passes; ++pass) { for (auto& arr : tape) { // Helper to check if we can fuse the parents of the @@ -312,18 +301,21 @@ std::vector compile_replace( } for (auto& a : tape) { + // Arrays in the tape without primitives are constants + // and can be used directly if (!a.has_primitive()) { - std::runtime_error( - "[compile] Something went wrong, no primitive in tape"); + trace_to_real.insert({a.id(), a}); + } else { + std::vector real_inputs; + for (auto& in : a.inputs()) { + real_inputs.push_back(trace_to_real.at(in.id())); + } + auto real_a = array( + a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); + trace_to_real.insert({a.id(), std::move(real_a)}); } - std::vector real_inputs; - for (auto& in : a.inputs()) { - real_inputs.push_back(trace_to_real.at(in.id())); - } - auto real_a = - array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)); - trace_to_real.insert({a.id(), std::move(real_a)}); } + std::vector outputs; for (auto& o : trace_outputs) { outputs.push_back(trace_to_real.at(o.id())); @@ -342,8 +334,7 @@ std::function(const std::vector&)> compile( if (entry.empty) { // Mark the entry as not empty since we are about to fill it entry.empty = false; - - // Trace te build the graph + // Trace to build the graph std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); // DFS the graph and get a tape, and a map of array id to (parent, diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index d2f825235..dc8c3e980 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -8,13 +8,42 @@ using namespace mlx::core; std::vector simple_fun(const std::vector& inputs) { return std::vector{inputs[0] + inputs[1]}; -}; +} TEST_CASE("test simple compile") { auto compfn = compile(simple_fun); - auto out = compfn({array(1.0), array(2.0)})[0]; + auto out = compfn({array(1.0f), array(2.0f)})[0]; CHECK_EQ(out.item(), 3.0f); - out = compfn({array(1.0), array(2.0)})[0]; + out = compfn({array(1.0f), array(2.0f)})[0]; CHECK_EQ(out.item(), 3.0f); + + // Change the shapes + out = compfn({array({1.0f, 2.0f}), array(2.0f)})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + out = compfn({array(2.0f), array({1.0f, 2.0f})})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + // Change the types + out = compfn({array(2, int32), array({1.0f, 2.0f})})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); + + out = compfn({array(2.0f), array({1, 2}, int32)})[0]; + CHECK(array_equal(out, array({3.0f, 4.0f})).item()); } + +std::vector fun1(const std::vector& inputs) { + auto loss = [](std::vector ins) { return exp(ins[0] + ins[1]); }; + return grad(loss)(inputs); +} + +TEST_CASE("test compile with grad") { + auto x = array(1.0f); + auto y = array(1.0f); + auto grads_expected = fun1({x, y}); + auto grads_compile = compile(fun1)({x, y}); + CHECK_EQ(grads_compile[0].item(), grads_expected[0].item()); +} + +TEST_CASE("test nested compile") {}