mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix multi output leak (#1548)
This commit is contained in:
parent
cde5b4ad80
commit
57c6aa7188
@ -271,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
|
||||
for (array& a : ad.inputs) {
|
||||
if (a.array_desc_) {
|
||||
input_map.insert({a.id(), a});
|
||||
for (auto& s : a.siblings()) {
|
||||
input_map.insert({s.id(), s});
|
||||
}
|
||||
}
|
||||
}
|
||||
ad.inputs.clear();
|
||||
|
@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def fun():
|
||||
a = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
b, _ = mx.divmod(a, a)
|
||||
return mx.log(b)
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def test_add_numpy(self):
|
||||
x = mx.array(1)
|
||||
y = np.array(2, dtype=np.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user