mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	shapeless slice update and broadcast when possible (#1727)
This commit is contained in:
		| @@ -817,6 +817,19 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         fun = mx.compile(lambda a, b: a @ b, shapeless=True) | ||||
|         self.assertTrue(mx.allclose(fun(a, b), a @ b)) | ||||
|  | ||||
|     def test_shapeless_compile_slice_update(self): | ||||
|         def fun(x): | ||||
|             x[2] = mx.array([3.0]) | ||||
|             return x | ||||
|  | ||||
|         cfun = mx.compile(fun, shapeless=True) | ||||
|  | ||||
|         a = mx.array([0.0, 1.0, 2.0, 3.0]) | ||||
|         self.assertTrue(mx.allclose(cfun(a), fun(a))) | ||||
|  | ||||
|         a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0]) | ||||
|         self.assertTrue(mx.allclose(cfun(a), fun(a))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun