mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 05:41:14 +08:00
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak * ignore arrays that will be detached * add some comments * stray print
This commit is contained in:
parent
19bef39f5c
commit
7f7b9662ea
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) {
|
||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
array::~array() {
|
||||
if (array_desc_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
return;
|
||||
}
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
// If all siblings have siblings.size() references except
|
||||
// the one we are currently destroying (which has siblings.size() + 1)
|
||||
// then there are no more external references
|
||||
do_detach &= (array_desc_.use_count() == (n + 1));
|
||||
for (auto& s : siblings()) {
|
||||
do_detach &= (s.array_desc_.use_count() == n);
|
||||
if (!do_detach) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (do_detach) {
|
||||
for (auto& s : siblings()) {
|
||||
for (auto& ss : s.siblings()) {
|
||||
ss.array_desc_ = nullptr;
|
||||
}
|
||||
s.array_desc_->siblings.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void array::ArrayDesc::init() {
|
||||
strides.resize(shape.size());
|
||||
size = 1;
|
||||
|
18
mlx/array.h
18
mlx/array.h
@ -261,22 +261,16 @@ class array {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
};
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
array_desc_->position = position;
|
||||
}
|
||||
|
||||
/** The i-th output of the array's primitive. */
|
||||
const array& output(int i) const {
|
||||
if (i == array_desc_->position) {
|
||||
return *this;
|
||||
} else if (i < array_desc_->position) {
|
||||
return siblings()[i];
|
||||
} else {
|
||||
return siblings()[i + 1];
|
||||
}
|
||||
};
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
@ -386,6 +380,8 @@ class array {
|
||||
array_desc_ = other.array_desc_;
|
||||
}
|
||||
|
||||
~array();
|
||||
|
||||
private:
|
||||
// Initialize the arrays data
|
||||
template <typename It>
|
||||
|
@ -11,7 +11,7 @@ GCC=$2
|
||||
SRCDIR=$3
|
||||
CLANG=$4
|
||||
|
||||
if [ $CLANG = "TRUE" ]; then
|
||||
if [ "$CLANG" = "TRUE" ]; then
|
||||
read -r -d '' INCLUDES <<- EOM
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
@ -246,7 +246,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
return;
|
||||
}
|
||||
a.set_tracer(false);
|
||||
for (auto s : a.siblings()) {
|
||||
for (auto& s : a.siblings()) {
|
||||
s.set_tracer(false);
|
||||
cache.insert(s.id());
|
||||
}
|
||||
@ -403,7 +403,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
||||
return;
|
||||
}
|
||||
a.set_tracer(false);
|
||||
for (auto s : a.siblings()) {
|
||||
for (auto& s : a.siblings()) {
|
||||
s.set_tracer(false);
|
||||
cache.insert(s.id());
|
||||
}
|
||||
|
@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
b = pickle.loads(pickle.dumps(a))
|
||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
def test_multi_output_leak(self):
|
||||
def fun():
|
||||
a = mx.zeros((2**20))
|
||||
mx.eval(a)
|
||||
b, c = mx.divmod(a, a)
|
||||
del b, c
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user