fix donation condition for compilation (#1237)

This commit is contained in:
Awni Hannun
2024-06-26 09:04:05 -07:00
committed by GitHub
parent 8c2e15e6c8
commit 5b0af4cdb1
2 changed files with 14 additions and 2 deletions

View File

@@ -707,6 +707,18 @@ class TestCompile(mlx_tests.MLXTestCase):
x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
def test_max_into_equal(self):
x = mx.random.uniform(shape=(1, 2, 2))
mx.eval(x)
def fn():
maxes = mx.max(x, axis=(1, 2), keepdims=True)
return x == maxes
out = mx.compile(fn)()
expected = fn()
self.assertTrue(mx.array_equal(expected, out))
if __name__ == "__main__":
unittest.main()