fix donation condition for compilation (#1237)

This commit is contained in:
Awni Hannun 2024-06-26 09:04:05 -07:00 committed by GitHub
parent 8c2e15e6c8
commit 5b0af4cdb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 2 deletions

View File

@ -205,8 +205,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(

View File

@ -707,6 +707,18 @@ class TestCompile(mlx_tests.MLXTestCase):
x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
def test_max_into_equal(self):
x = mx.random.uniform(shape=(1, 2, 2))
mx.eval(x)
def fn():
maxes = mx.max(x, axis=(1, 2), keepdims=True)
return x == maxes
out = mx.compile(fn)()
expected = fn()
self.assertTrue(mx.array_equal(expected, out))
if __name__ == "__main__":
unittest.main()