mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives * hopefully an actual fix
This commit is contained in:
		| @@ -103,6 +103,23 @@ class TestEval(mlx_tests.MLXTestCase): | ||||
|         z = mx.add(y, x, stream=mx.cpu) | ||||
|         self.assertTrue(mx.allclose(z, mx.full((8000,), 22.0))) | ||||
|  | ||||
|     def test_multi_output_eval_during_transform(self): | ||||
|         x = mx.random.uniform(shape=(1024,)) | ||||
|         y = mx.ones((1024,)) | ||||
|         mx.eval(x, y) | ||||
|  | ||||
|         def fn(x): | ||||
|             a, b = mx.divmod(x, x) | ||||
|             mx.eval(a) | ||||
|             return a | ||||
|  | ||||
|         out = mx.vjp(fn, (x,), (y,)) | ||||
|         out = mx.vjp(fn, (x,), (y,)) | ||||
|         if mx.metal.is_available(): | ||||
|             peak_mem = mx.metal.get_peak_memory() | ||||
|             out = mx.vjp(fn, (x,), (y,)) | ||||
|             self.assertEqual(peak_mem, mx.metal.get_peak_memory()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun