shapeless slice update and broadcast when possible (#1727)

This commit is contained in:
Awni Hannun
2024-12-23 11:25:15 -08:00
committed by GitHub
parent 0308e9af71
commit ebfe64b92d
6 changed files with 43 additions and 99 deletions

View File

@@ -817,6 +817,19 @@ class TestCompile(mlx_tests.MLXTestCase):
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
self.assertTrue(mx.allclose(fun(a, b), a @ b))
def test_shapeless_compile_slice_update(self):
def fun(x):
x[2] = mx.array([3.0])
return x
cfun = mx.compile(fun, shapeless=True)
a = mx.array([0.0, 1.0, 2.0, 3.0])
self.assertTrue(mx.allclose(cfun(a), fun(a)))
a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])
self.assertTrue(mx.allclose(cfun(a), fun(a)))
if __name__ == "__main__":
unittest.main()