From 57c6aa7188b45949228ded8e60b1daabf1bedcfb Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 31 Oct 2024 09:32:01 -0700 Subject: [PATCH] fix multi output leak (#1548) --- mlx/array.cpp | 3 +++ python/tests/test_array.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mlx/array.cpp b/mlx/array.cpp index bb92989c3..2c70b9c8e 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() { for (array& a : ad.inputs) { if (a.array_desc_) { input_map.insert({a.id(), a}); + for (auto& s : a.siblings()) { + input_map.insert({s.id(), s}); + } } } ad.inputs.clear(); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index da14675a0..db0e26aa9 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase): peak_2 = mx.metal.get_peak_memory() self.assertEqual(peak_1, peak_2) + def fun(): + a = mx.array([1.0, 2.0, 3.0, 4.0]) + b, _ = mx.divmod(a, a) + return mx.log(b) + + fun() + mx.synchronize() + peak_1 = mx.metal.get_peak_memory() + fun() + mx.synchronize() + peak_2 = mx.metal.get_peak_memory() + self.assertEqual(peak_1, peak_2) + def test_add_numpy(self): x = mx.array(1) y = np.array(2, dtype=np.int32)