From 5b0af4cdb109c9b02bd897f71d99190cd588cd3c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 26 Jun 2024 09:04:05 -0700 Subject: [PATCH] fix donation condition for compilation (#1237) --- mlx/backend/common/compiled.cpp | 4 ++-- python/tests/test_compile.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 762939b1e..e847017c7 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -205,8 +205,8 @@ void compiled_allocate_outputs( // - Donatable // - Correct size // - Not a constant - if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() && - in.is_donatable() && + if (in.flags().row_contiguous && in.size() == outputs[o].size() && + in.itemsize() == outputs[o].itemsize() && in.is_donatable() && constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) { if (move_buffers) { outputs[o].move_shared_buffer( diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index fb058203e..82773dbf2 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -707,6 +707,18 @@ class TestCompile(mlx_tests.MLXTestCase): x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16) self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x))) + def test_max_into_equal(self): + x = mx.random.uniform(shape=(1, 2, 2)) + mx.eval(x) + + def fn(): + maxes = mx.max(x, axis=(1, 2), keepdims=True) + return x == maxes + + out = mx.compile(fn)() + expected = fn() + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main()