fix multi output leak (#1548)

This commit is contained in:
Awni Hannun
2024-10-31 09:32:01 -07:00
committed by GitHub
parent cde5b4ad80
commit 57c6aa7188
2 changed files with 16 additions and 0 deletions

View File

@@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def fun():
a = mx.array([1.0, 2.0, 3.0, 4.0])
b, _ = mx.divmod(a, a)
return mx.log(b)
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self):
x = mx.array(1)
y = np.array(2, dtype=np.int32)