From 66fcb9fe94125fcf61a2f0fccc255220cc98c669 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Wed, 29 Oct 2025 16:02:31 -0700 Subject: [PATCH] array: use int or int64_t instead of size_t --- mlx/array.cpp | 15 ++++++++------- mlx/array.h | 14 +++++++------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index a05e8dfa7..23b322a7b 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -44,11 +44,11 @@ std::vector array::make_arrays( const std::shared_ptr& primitive, const std::vector& inputs) { std::vector outputs; - for (size_t i = 0; i < shapes.size(); ++i) { + for (int i = 0; i < std::ssize(shapes); ++i) { outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs); } // For each node in |outputs|, its siblings are the other nodes. - for (size_t i = 0; i < outputs.size(); ++i) { + for (int i = 0; i < std::ssize(outputs); ++i) { auto siblings = outputs; siblings.erase(siblings.begin() + i); outputs[i].set_siblings(std::move(siblings), i); @@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) { array_desc_->data_size = size(); array_desc_->flags.contiguous = true; array_desc_->flags.row_contiguous = true; - auto max_dim = std::max_element(shape().begin(), shape().end()); - array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim; + auto max_dim = + static_cast(*std::max_element(shape().begin(), shape().end())); + array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim; } void array::set_data( @@ -192,7 +193,7 @@ array::~array() { } // Break circular reference for non-detached arrays with siblings - if (auto n = siblings().size(); n > 0) { + if (auto n = std::ssize(siblings()); n > 0) { bool do_detach = true; // If all siblings have siblings.size() references except // the one we are currently destroying (which has siblings.size() + 1) @@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() { ad.inputs.clear(); for (auto& [_, a] : input_map) { bool is_deletable = - (a.array_desc_.use_count() <= a.siblings().size() + 1); + (a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1); // An array with siblings is deletable only if all of its siblings // are deletable for (auto& s : a.siblings()) { @@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() { } int is_input = (input_map.find(s.id()) != input_map.end()); is_deletable &= - s.array_desc_.use_count() <= a.siblings().size() + is_input; + s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input; } if (is_deletable) { for_deletion.push_back(std::move(a.array_desc_)); diff --git a/mlx/array.h b/mlx/array.h index 4e9a5ae63..4be4835e1 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -81,22 +81,22 @@ class array { } /** The size of the array's datatype in bytes. */ - size_t itemsize() const { + int itemsize() const { return size_of(dtype()); } /** The number of elements in the array. */ - size_t size() const { + int64_t size() const { return array_desc_->size; } /** The number of bytes in the array. */ - size_t nbytes() const { + int64_t nbytes() const { return size() * itemsize(); } /** The number of dimensions of the array. */ - size_t ndim() const { + int ndim() const { return array_desc_->shape.size(); } @@ -329,7 +329,7 @@ class array { * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. * Note, ``data_size`` is in units of ``item_size`` (not bytes). **/ - size_t data_size() const { + int64_t data_size() const { return array_desc_->data_size; } @@ -340,7 +340,7 @@ class array { return array_desc_->data->buffer; } - size_t buffer_size() const { + int64_t buffer_size() const { return allocator::allocator().size(buffer()); } @@ -530,7 +530,7 @@ array::array( Shape shape, Dtype dtype /* = TypeToDtype() */) : array_desc_(std::make_shared(std::move(shape), dtype)) { - if (data.size() != size()) { + if (std::ssize(data) != size()) { throw std::invalid_argument( "Data size and provided shape mismatch in array construction."); }