mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix multi output leak (#1548)
This commit is contained in:
parent
cde5b4ad80
commit
57c6aa7188
@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
|
|||||||
for (array& a : ad.inputs) {
|
for (array& a : ad.inputs) {
|
||||||
if (a.array_desc_) {
|
if (a.array_desc_) {
|
||||||
input_map.insert({a.id(), a});
|
input_map.insert({a.id(), a});
|
||||||
|
for (auto& s : a.siblings()) {
|
||||||
|
input_map.insert({s.id(), s});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ad.inputs.clear();
|
ad.inputs.clear();
|
||||||
|
@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
peak_2 = mx.metal.get_peak_memory()
|
peak_2 = mx.metal.get_peak_memory()
|
||||||
self.assertEqual(peak_1, peak_2)
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
|
def fun():
|
||||||
|
a = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||||
|
b, _ = mx.divmod(a, a)
|
||||||
|
return mx.log(b)
|
||||||
|
|
||||||
|
fun()
|
||||||
|
mx.synchronize()
|
||||||
|
peak_1 = mx.metal.get_peak_memory()
|
||||||
|
fun()
|
||||||
|
mx.synchronize()
|
||||||
|
peak_2 = mx.metal.get_peak_memory()
|
||||||
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
def test_add_numpy(self):
|
def test_add_numpy(self):
|
||||||
x = mx.array(1)
|
x = mx.array(1)
|
||||||
y = np.array(2, dtype=np.int32)
|
y = np.array(2, dtype=np.int32)
|
||||||
|
Loading…
Reference in New Issue
Block a user