mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak * ignore arrays that will be detached * add some comments * stray print
This commit is contained in:
@@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
b = pickle.loads(pickle.dumps(a))
|
||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_multi_output_leak(self):
|
||||
def fun():
|
||||
a = mx.zeros((2**20))
|
||||
mx.eval(a)
|
||||
b, c = mx.divmod(a, a)
|
||||
del b, c
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user