From 0e95b6494278ace2f5c748462d4dfc9656be355a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 11 Mar 2024 17:29:05 -0700 Subject: [PATCH] Fix bug in tape order during simplify (#816) * fix bug in tape order during simplify * properly fix compile * last bug --- mlx/compile.cpp | 19 +++++++++++++---- python/tests/test_compile.py | 40 ++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 0ede320e5..5c3fbf438 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -415,6 +415,11 @@ void compile_simplify( } tape = std::move(new_tape); + std::unordered_map<std::uintptr_t, uint32_t> tape_order; + for (uint32_t i = 0; i < tape.size(); ++i) { + tape_order.insert({tape[i].id(), i}); + } + std::unordered_set<uintptr_t> output_set; for (auto& o : outputs) { output_set.insert(o.id()); @@ -437,17 +442,23 @@ void compile_simplify( if (mask[j]) { continue; } - auto& src = parents->second[j].first; - auto& dst = parents->second[i].first; + auto src_idx = j; + auto dst_idx = i; + if (tape_order[parents->second[src_idx].first.id()] < + tape_order[parents->second[dst_idx].first.id()]) { + std::swap(src_idx, dst_idx); + } + auto& src = parents->second[src_idx].first; + auto& dst = parents->second[dst_idx].first; if (src.id() != dst.id() && array_equivalent(src, dst) && output_set.find(src.id()) == output_set.end()) { merge(dst, src, parents_map); - mask[j] = true; + mask[src_idx] = true; } } } // Erase orphaned parents so we don't keep fusing with them - for (int i = N - 1; i > 0; --i) { + for (int i = N - 1; i >= 0; --i) { if (mask[i]) { parents->second.erase(parents->second.begin() + i); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 98d7f7276..beba9ca95 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -613,6 +613,46 @@ class TestCompile(mlx_tests.MLXTestCase): out = fun() mx.eval(out) + def test_compile_vjp(self): + def fun(w): + w1 = w + w + w2 = w + w + return w @ 
w1 + w2 @ w2 + + def step(w): + out, grad = mx.vjp(fun, (w,), (mx.array([[1.0, 1.0], [1.0, 1.0]]),)) + return out[0], grad[0] + + w = mx.zeros((2, 2)) + mx.eval(w) + + expected = step(w) + out = mx.compile(step)(w) + self.assertTrue(mx.allclose(expected[0], out[0])) + self.assertTrue(mx.allclose(expected[1], out[1])) + + def fun(w1, w2, x): + x = x @ w1 + y = x @ w2 + x = x + y * y + return (x * x).sum() + + w1 = mx.zeros((4, 4)) + w2 = mx.zeros((4, 4)) + x = mx.zeros((4, 4)) + + def step(w1, w2, x): + loss, gradient = mx.value_and_grad(fun)(w1, w2, x) + w1 = w1 + gradient + return loss, w1 + + mx.eval(x, w1, w2) + expected = step(w1, w2, x) + out = mx.compile(step)(w1, w2, x) + + self.assertTrue(mx.allclose(expected[0], out[0])) + self.assertTrue(mx.allclose(expected[1], out[1])) + if __name__ == "__main__": unittest.main()