mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	more donation take 2
This commit is contained in:
		@@ -195,6 +195,36 @@ class TestEval(mlx_tests.MLXTestCase):
 | 
			
		||||
        mx.eval(z)
 | 
			
		||||
        mx.set_memory_limit(old_limit)
 | 
			
		||||
 | 
			
		||||
    def test_donation_multiple_inputs(self):
 | 
			
		||||
        def fun(its, x, y):
 | 
			
		||||
            for _ in range(its):
 | 
			
		||||
                a = x + y  # y should donate
 | 
			
		||||
                b = x + a  # x should donate
 | 
			
		||||
                x, y = a, b
 | 
			
		||||
            return x, y
 | 
			
		||||
 | 
			
		||||
        x = mx.zeros((128, 128))
 | 
			
		||||
        y = mx.zeros((128, 128))
 | 
			
		||||
        mx.reset_peak_memory()
 | 
			
		||||
        a, b = fun(2, x, y)
 | 
			
		||||
        mx.eval(a, b)
 | 
			
		||||
        mx.synchronize()
 | 
			
		||||
        mem2 = mx.get_peak_memory()
 | 
			
		||||
        a, b = fun(10, x, y)
 | 
			
		||||
        mx.eval(a, b)
 | 
			
		||||
        mx.synchronize()
 | 
			
		||||
        mem10 = mx.get_peak_memory()
 | 
			
		||||
        self.assertEqual(mem2, mem10)
 | 
			
		||||
 | 
			
		||||
    def test_async_with_delete(self):
 | 
			
		||||
        a = mx.ones((5, 5))
 | 
			
		||||
        for _ in range(100):
 | 
			
		||||
            a = mx.abs(a)
 | 
			
		||||
        mx.async_eval(a)
 | 
			
		||||
        del a
 | 
			
		||||
        mx.clear_cache()
 | 
			
		||||
        mx.synchronize()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user