diff --git a/mlx/array.h b/mlx/array.h index d690dcd97..66a4702a6 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -339,11 +339,11 @@ class array { return allocator::allocator().size(buffer()); } - // Return a copy of the shared pointer - // to the array::Data struct - std::shared_ptr data_shared_ptr() const { + // Return the shared pointer to the array::Data struct + const std::shared_ptr& data_shared_ptr() const { return array_desc_->data; } + // Return a raw pointer to the arrays data template T* data() {