mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
shapeless slice update and broadcast when possible (#1727)
This commit is contained in:
@@ -817,6 +817,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
|
||||
self.assertTrue(mx.allclose(fun(a, b), a @ b))
|
||||
|
||||
def test_shapeless_compile_slice_update(self):
|
||||
def fun(x):
|
||||
x[2] = mx.array([3.0])
|
||||
return x
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
|
||||
a = mx.array([0.0, 1.0, 2.0, 3.0])
|
||||
self.assertTrue(mx.allclose(cfun(a), fun(a)))
|
||||
|
||||
a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])
|
||||
self.assertTrue(mx.allclose(cfun(a), fun(a)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user