fix multi output leak (#1548)

This commit is contained in:
Awni Hannun 2024-10-31 09:32:01 -07:00 committed by GitHub
parent cde5b4ad80
commit 57c6aa7188
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 0 deletions

View File

@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
for (array& a : ad.inputs) { for (array& a : ad.inputs) {
if (a.array_desc_) { if (a.array_desc_) {
input_map.insert({a.id(), a}); input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
} }
} }
ad.inputs.clear(); ad.inputs.clear();

View File

@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
peak_2 = mx.metal.get_peak_memory() peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2) self.assertEqual(peak_1, peak_2)
def fun():
a = mx.array([1.0, 2.0, 3.0, 4.0])
b, _ = mx.divmod(a, a)
return mx.log(b)
fun()
mx.synchronize()
peak_1 = mx.metal.get_peak_memory()
fun()
mx.synchronize()
peak_2 = mx.metal.get_peak_memory()
self.assertEqual(peak_1, peak_2)
def test_add_numpy(self): def test_add_numpy(self):
x = mx.array(1) x = mx.array(1)
y = np.array(2, dtype=np.int32) y = np.array(2, dtype=np.int32)