From 7f7b9662eac8cdb2ddfef993006bd42a8a3f2dd7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 1 May 2024 07:31:45 -0700 Subject: [PATCH] Fix leak for multi-output primitives which are never detached (#1059) * fix multi output leak * ignore arrays that will be detached * add some comments * stray print --- mlx/array.cpp | 34 +++++++++++++++++++- mlx/array.h | 18 ++++------- mlx/backend/common/make_compiled_preamble.sh | 2 +- mlx/transforms.cpp | 4 +-- python/tests/test_array.py | 16 +++++++++ 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index f655bc6a6..3f463ac19 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include "mlx/array.h" @@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) { move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } +array::~array() { + if (array_desc_ == nullptr) { + return; + } + + // Ignore arrays that will be detached + if (status() != array::Status::unscheduled) { + return; + } + // Break circular reference for non-detached arrays with siblings + if (auto n = siblings().size(); n > 0) { + bool do_detach = true; + // If all siblings have siblings.size() references except + // the one we are currently destroying (which has siblings.size() + 1) + // then there are no more external references + do_detach &= (array_desc_.use_count() == (n + 1)); + for (auto& s : siblings()) { + do_detach &= (s.array_desc_.use_count() == n); + if (!do_detach) { + break; + } + } + if (do_detach) { + for (auto& s : siblings()) { + for (auto& ss : s.siblings()) { + ss.array_desc_ = nullptr; + } + s.array_desc_->siblings.clear(); + } + } + } +} + void array::ArrayDesc::init() { strides.resize(shape.size()); size = 1; diff --git a/mlx/array.h b/mlx/array.h index b576ff2c6..96b6b971e 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -261,22 +261,16 @@ class array { return array_desc_->siblings; }; + /** The array's siblings. */ + std::vector& siblings() { + return array_desc_->siblings; + }; + void set_siblings(std::vector siblings, uint16_t position) { array_desc_->siblings = std::move(siblings); array_desc_->position = position; } - /** The i-th output of the array's primitive. */ - const array& output(int i) const { - if (i == array_desc_->position) { - return *this; - } else if (i < array_desc_->position) { - return siblings()[i]; - } else { - return siblings()[i + 1]; - } - }; - /** The outputs of the array's primitive (i.e. this array and * its siblings) in the order the primitive expects. */ std::vector outputs() const { @@ -386,6 +380,8 @@ class array { array_desc_ = other.array_desc_; } + ~array(); + private: // Initialize the arrays data template diff --git a/mlx/backend/common/make_compiled_preamble.sh b/mlx/backend/common/make_compiled_preamble.sh index 687f4cfc7..050fce25e 100644 --- a/mlx/backend/common/make_compiled_preamble.sh +++ b/mlx/backend/common/make_compiled_preamble.sh @@ -11,7 +11,7 @@ GCC=$2 SRCDIR=$3 CLANG=$4 -if [ $CLANG = "TRUE" ]; then +if [ "$CLANG" = "TRUE" ]; then read -r -d '' INCLUDES <<- EOM #include #include diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index ace64a14a..005402a98 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -246,7 +246,7 @@ std::pair, std::vector> vjp( return; } a.set_tracer(false); - for (auto s : a.siblings()) { + for (auto& s : a.siblings()) { s.set_tracer(false); cache.insert(s.id()); } @@ -403,7 +403,7 @@ std::pair, std::vector> jvp( return; } a.set_tracer(false); - for (auto s : a.siblings()) { + for (auto& s : a.siblings()) { s.set_tracer(false); cache.insert(s.id()); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86478c33b..4faa3ec1c 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1694,6 +1694,22 @@ class TestArray(mlx_tests.MLXTestCase): b = pickle.loads(pickle.dumps(a)) self.assertTrue(mx.array_equal(mx.array(a), mx.array(b))) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_multi_output_leak(self): + def fun(): + a = mx.zeros((2**20)) + mx.eval(a) + b, c = mx.divmod(a, a) + del b, c + + fun() + mx.synchronize() + peak_1 = mx.metal.get_peak_memory() + fun() + mx.synchronize() + peak_2 = mx.metal.get_peak_memory() + self.assertEqual(peak_1, peak_2) + if __name__ == "__main__": unittest.main()