Fix leak for multi-output primitives which are never detached (#1059)

* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
This commit is contained in:
Awni Hannun 2024-05-01 07:31:45 -07:00 committed by GitHub
parent 19bef39f5c
commit 7f7b9662ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 15 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include "mlx/array.h"
@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
/**
 * Destructor. For an unscheduled array that has siblings, checks whether
 * the only remaining references to this array and its siblings are the
 * ones they hold on each other, and if so drops those references so the
 * group does not keep itself alive.
 */
array::~array() {
  // An empty handle owns nothing
  if (array_desc_ == nullptr) {
    return;
  }

  // Arrays which are not unscheduled will be detached, so skip them
  if (status() != array::Status::unscheduled) {
    return;
  }

  // Without siblings there is no circular reference to break
  auto num_siblings = siblings().size();
  if (num_siblings == 0) {
    return;
  }

  // The array being destroyed is expected to have exactly
  // siblings.size() + 1 references; any other count means
  // the group must be left untouched
  if (array_desc_.use_count() != (num_siblings + 1)) {
    return;
  }

  // Every sibling is expected to have exactly siblings.size() references
  for (auto& sib : siblings()) {
    if (sib.array_desc_.use_count() != num_siblings) {
      return;
    }
  }

  // No external references remain: release the references the siblings
  // hold on each other, then empty each sibling's own sibling list
  for (auto& sib : siblings()) {
    for (auto& nested : sib.siblings()) {
      nested.array_desc_ = nullptr;
    }
    sib.array_desc_->siblings.clear();
  }
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;

View File

@ -261,22 +261,16 @@ class array {
return array_desc_->siblings;
};
/** Mutable access to the sibling arrays stored in this array's descriptor. */
std::vector<array>& siblings() {
  auto& sibling_list = array_desc_->siblings;
  return sibling_list;
}
/** Replace the descriptor's sibling list (taking ownership of `siblings`)
 * and store `position` in the descriptor. */
void set_siblings(std::vector<array> siblings, uint16_t position) {
  auto& desc = *array_desc_;
  desc.siblings = std::move(siblings);
  desc.position = position;
}
/**
 * The i-th output of the array's primitive.
 *
 * The outputs are this array (at index `position`) together with its
 * siblings. This array is not stored in its own sibling list, so an
 * output located after it sits one slot earlier in that list.
 *
 * @param i Index into the primitive's outputs; the caller is expected to
 *          pass a value in [0, siblings().size()].
 * @return A reference to this array when `i` equals its position,
 *         otherwise a reference to the corresponding sibling.
 */
const array& output(int i) const {
  const int position = array_desc_->position;
  if (i == position) {
    return *this;
  }
  // Outputs before this array map directly onto the sibling list; outputs
  // after it are shifted down by one. (Indexing with i + 1 here would skip
  // a sibling and run past the end of the list for the last output.)
  return siblings()[i < position ? i : i - 1];
}
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
@ -386,6 +380,8 @@ class array {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the arrays data
template <typename It>

View File

@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3
CLANG=$4
if [ $CLANG = "TRUE" ]; then
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>

View File

@ -246,7 +246,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
return;
}
a.set_tracer(false);
for (auto s : a.siblings()) {
for (auto& s : a.siblings()) {
s.set_tracer(false);
cache.insert(s.id());
}
@ -403,7 +403,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
return;
}
a.set_tracer(false);
for (auto s : a.siblings()) {
for (auto& s : a.siblings()) {
s.set_tracer(false);
cache.insert(s.id());
}

View File

@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase):
b = pickle.loads(pickle.dumps(a))
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
def test_multi_output_leak(self):
    # Run the same multi-output workload twice and require that the
    # reported peak memory is the same after each run.
    def make_and_drop_outputs():
        # Build a two-output op (divmod) and drop both results without
        # evaluating them.
        num_elements = 2**20
        a = mx.zeros(num_elements)
        mx.eval(a)
        b, c = mx.divmod(a, a)
        del b, c

    peaks = []
    for _ in range(2):
        make_and_drop_outputs()
        mx.synchronize()
        peaks.append(mx.metal.get_peak_memory())

    first_peak, second_peak = peaks
    self.assertEqual(first_peak, second_peak)
if __name__ == "__main__":
unittest.main()