diff --git a/mlx/array.cpp b/mlx/array.cpp index edbaad395..ff058a833 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -192,6 +192,36 @@ array::ArrayDesc::ArrayDesc( init(); } +array::ArrayDesc::~ArrayDesc() { + // When an array description is destroyed it will delete a bunch of arrays + // that may also destory their corresponding descriptions and so on and so + // forth. + // + // This calls recursively the destructor and can result in stack overflow, we + // instead put them in a vector and destroy them one at a time resulting in a + // max stack depth of 2. + std::vector> for_deletion; + + for (array& a : inputs) { + if (a.array_desc_.use_count() == 1) { + for_deletion.push_back(std::move(a.array_desc_)); + } + } + + while (!for_deletion.empty()) { + // top is going to be deleted at the end of the block *after* the arrays + // with inputs have been moved into the vector + auto top = std::move(for_deletion.back()); + for_deletion.pop_back(); + + for (array& a : top->inputs) { + if (a.array_desc_.use_count() == 1) { + for_deletion.push_back(std::move(a.array_desc_)); + } + } + } +} + array::ArrayIterator::ArrayIterator(const array& arr, int idx) : arr(arr), idx(idx) { if (arr.ndim() == 0) { diff --git a/mlx/array.h b/mlx/array.h index b142f90c8..a3b2b2c44 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -404,6 +404,8 @@ class array { std::shared_ptr primitive, std::vector inputs); + ~ArrayDesc(); + private: // Initialize size, strides, and other metadata void init();