mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak
* ignore arrays that will be detached
* add some comments
* stray print
This commit is contained in:
parent
19bef39f5c
commit
7f7b9662ea
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) {
|
|||||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Destructor. In addition to the implicit member cleanup, it breaks the
// reference cycle among the outputs of a multi-output primitive: every
// output's descriptor stores its siblings, so the group keeps itself alive
// unless the sibling lists are emptied once no reference from outside the
// group remains.
array::~array() {
|
||||||
|
// Nothing to do for a handle that holds no descriptor. Entries whose
// array_desc_ was reset by the detach loop below take this path when they
// are destroyed.
if (array_desc_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore arrays that will be detached
// (any status other than unscheduled is skipped here; the detach itself
// happens outside this function and is not visible in this file section)
|
||||||
|
if (status() != array::Status::unscheduled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Break circular reference for non-detached arrays with siblings
|
||||||
|
if (auto n = siblings().size(); n > 0) {
|
||||||
|
bool do_detach = true;
|
||||||
|
// If all siblings have siblings.size() references except
|
||||||
|
// the one we are currently destroying (which has siblings.size() + 1)
|
||||||
|
// then there are no more external references
|
||||||
|
// This descriptor: n references from the siblings' lists plus this handle.
do_detach &= (array_desc_.use_count() == (n + 1));
|
||||||
|
for (auto& s : siblings()) {
|
||||||
|
// A sibling's descriptor: exactly n references, all from inside the group.
do_detach &= (s.array_desc_.use_count() == n);
|
||||||
|
// Stop at the first sibling that has an extra (external) reference.
if (!do_detach) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (do_detach) {
|
||||||
|
for (auto& s : siblings()) {
|
||||||
|
// Drop the descriptor reference held by each entry of this sibling's
// list before clearing it, so destroying those entries hits the
// null-descriptor early return at the top of this destructor.
for (auto& ss : s.siblings()) {
|
||||||
|
ss.array_desc_ = nullptr;
|
||||||
|
}
|
||||||
|
s.array_desc_->siblings.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void array::ArrayDesc::init() {
|
void array::ArrayDesc::init() {
|
||||||
strides.resize(shape.size());
|
strides.resize(shape.size());
|
||||||
size = 1;
|
size = 1;
|
||||||
|
18
mlx/array.h
18
mlx/array.h
@ -261,22 +261,16 @@ class array {
|
|||||||
return array_desc_->siblings;
|
return array_desc_->siblings;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** The array's siblings, returned by non-const reference to the list
 * stored in this array's descriptor, so callers can modify that list
 * in place. */
|
||||||
|
std::vector<array>& siblings() {
|
||||||
|
return array_desc_->siblings;
|
||||||
|
};
|
||||||
|
|
||||||
/** Store the given sibling list (moved into the array's descriptor) and
 * record `position` in the descriptor.
 * NOTE(review): `position` looks like this array's index among the
 * primitive's outputs -- confirm against `outputs()`. */
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||||
array_desc_->siblings = std::move(siblings);
|
array_desc_->siblings = std::move(siblings);
|
||||||
array_desc_->position = position;
|
array_desc_->position = position;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The i-th output of the array's primitive.
 *
 * Output `position` is this array itself. The other outputs are stored in
 * `siblings()`, which does not contain this array, so outputs before
 * `position` are found at index `i` and outputs after it at index `i - 1`.
 * No bounds check is performed on `i`. */
|
|
||||||
const array& output(int i) const {
|
|
||||||
if (i == array_desc_->position) {
|
|
||||||
return *this;
|
|
||||||
} else if (i < array_desc_->position) {
|
|
||||||
return siblings()[i];
|
|
||||||
} else {
|
|
||||||
// Skip over this array's own slot: the sibling list is one shorter than
// the output list, so later outputs shift down by one.
return siblings()[i - 1];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/** The outputs of the array's primitive (i.e. this array and
|
/** The outputs of the array's primitive (i.e. this array and
|
||||||
* its siblings) in the order the primitive expects. */
|
* its siblings) in the order the primitive expects. */
|
||||||
std::vector<array> outputs() const {
|
std::vector<array> outputs() const {
|
||||||
@ -386,6 +380,8 @@ class array {
|
|||||||
array_desc_ = other.array_desc_;
|
array_desc_ = other.array_desc_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
~array();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Initialize the arrays data
|
// Initialize the arrays data
|
||||||
template <typename It>
|
template <typename It>
|
||||||
|
@ -11,7 +11,7 @@ GCC=$2
|
|||||||
SRCDIR=$3
|
SRCDIR=$3
|
||||||
CLANG=$4
|
CLANG=$4
|
||||||
|
|
||||||
if [ $CLANG = "TRUE" ]; then
|
if [ "$CLANG" = "TRUE" ]; then
|
||||||
read -r -d '' INCLUDES <<- EOM
|
read -r -d '' INCLUDES <<- EOM
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
@ -246,7 +246,7 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
a.set_tracer(false);
|
a.set_tracer(false);
|
||||||
for (auto s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
s.set_tracer(false);
|
s.set_tracer(false);
|
||||||
cache.insert(s.id());
|
cache.insert(s.id());
|
||||||
}
|
}
|
||||||
@ -403,7 +403,7 @@ std::pair<std::vector<array>, std::vector<array>> jvp(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
a.set_tracer(false);
|
a.set_tracer(false);
|
||||||
for (auto s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
s.set_tracer(false);
|
s.set_tracer(false);
|
||||||
cache.insert(s.id());
|
cache.insert(s.id());
|
||||||
}
|
}
|
||||||
|
@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
b = pickle.loads(pickle.dumps(a))
|
b = pickle.loads(pickle.dumps(a))
|
||||||
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
self.assertTrue(mx.array_equal(mx.array(a), mx.array(b)))
|
||||||
|
|
||||||
|
# Regression test for a leak with multi-output primitives: creating and
# deleting the two outputs of mx.divmod a second time must not change the
# peak memory reported by Metal. Skipped when Metal is unavailable.
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||||
|
def test_multi_output_leak(self):
|
||||||
|
# One pass: evaluate an input array, build both divmod outputs from it,
# then delete them. b and c are never passed to mx.eval.
def fun():
|
||||||
|
a = mx.zeros((2**20))
|
||||||
|
mx.eval(a)
|
||||||
|
b, c = mx.divmod(a, a)
|
||||||
|
del b, c
|
||||||
|
|
||||||
|
# First pass establishes the baseline peak.
fun()
|
||||||
|
mx.synchronize()
|
||||||
|
peak_1 = mx.metal.get_peak_memory()
|
||||||
|
# Second pass: the peak is expected to be unchanged.
fun()
|
||||||
|
mx.synchronize()
|
||||||
|
peak_2 = mx.metal.get_peak_memory()
|
||||||
|
self.assertEqual(peak_1, peak_2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user