mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	fix donation in scan (#1917)
This commit is contained in:
		| @@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             expected = mx.repeat(expected[:, None], 2, axis=1) | ||||
|             self.assertTrue(mx.array_equal(expected, out)) | ||||
|  | ||||
|         # Test donation | ||||
|         def fn(its): | ||||
|             x = mx.ones((32,)) | ||||
|             for _ in range(its): | ||||
|                 x = mx.cumsum(x) | ||||
|             return x | ||||
|  | ||||
|         mx.synchronize(mx.default_stream(mx.default_device())) | ||||
|         mx.eval(fn(2)) | ||||
|         mx.synchronize(mx.default_stream(mx.default_device())) | ||||
|         mem2 = mx.metal.get_peak_memory() | ||||
|         mx.eval(fn(4)) | ||||
|         mx.synchronize(mx.default_stream(mx.default_device())) | ||||
|         mem4 = mx.metal.get_peak_memory() | ||||
|         self.assertEqual(mem2, mem4) | ||||
|  | ||||
|     def test_squeeze_expand(self): | ||||
|         a = mx.zeros((2, 1, 2, 1)) | ||||
|         self.assertEqual(mx.squeeze(a).shape, (2, 2)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun