array: use int or int64_t instead of size_t

This commit is contained in:
Ronan Collobert
2025-10-29 16:02:31 -07:00
parent d1e06117e8
commit 66fcb9fe94
2 changed files with 15 additions and 14 deletions

View File

@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) { const std::vector<array>& inputs) {
std::vector<array> outputs; std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) { for (int i = 0; i < std::ssize(shapes); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs); outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
} }
// For each node in |outputs|, its siblings are the other nodes. // For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) { for (int i = 0; i < std::ssize(outputs); ++i) {
auto siblings = outputs; auto siblings = outputs;
siblings.erase(siblings.begin() + i); siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i); outputs[i].set_siblings(std::move(siblings), i);
@@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data_size = size(); array_desc_->data_size = size();
array_desc_->flags.contiguous = true; array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true; array_desc_->flags.row_contiguous = true;
auto max_dim = std::max_element(shape().begin(), shape().end()); auto max_dim =
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim; static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
} }
void array::set_data( void array::set_data(
@@ -192,7 +193,7 @@ array::~array() {
} }
// Break circular reference for non-detached arrays with siblings // Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) { if (auto n = std::ssize(siblings()); n > 0) {
bool do_detach = true; bool do_detach = true;
// If all siblings have siblings.size() references except // If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1) // the one we are currently destroying (which has siblings.size() + 1)
@@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() {
ad.inputs.clear(); ad.inputs.clear();
for (auto& [_, a] : input_map) { for (auto& [_, a] : input_map) {
bool is_deletable = bool is_deletable =
(a.array_desc_.use_count() <= a.siblings().size() + 1); (a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
// An array with siblings is deletable only if all of its siblings // An array with siblings is deletable only if all of its siblings
// are deletable // are deletable
for (auto& s : a.siblings()) { for (auto& s : a.siblings()) {
@@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() {
} }
int is_input = (input_map.find(s.id()) != input_map.end()); int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &= is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input; s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
} }
if (is_deletable) { if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_)); for_deletion.push_back(std::move(a.array_desc_));

View File

@@ -81,22 +81,22 @@ class array {
} }
/** The size of the array's datatype in bytes. */ /** The size of the array's datatype in bytes. */
size_t itemsize() const { int itemsize() const {
return size_of(dtype()); return size_of(dtype());
} }
/** The number of elements in the array. */ /** The number of elements in the array. */
size_t size() const { int64_t size() const {
return array_desc_->size; return array_desc_->size;
} }
/** The number of bytes in the array. */ /** The number of bytes in the array. */
size_t nbytes() const { int64_t nbytes() const {
return size() * itemsize(); return size() * itemsize();
} }
/** The number of dimensions of the array. */ /** The number of dimensions of the array. */
size_t ndim() const { int ndim() const {
return array_desc_->shape.size(); return array_desc_->shape.size();
} }
@@ -329,7 +329,7 @@ class array {
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes). * Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/ **/
size_t data_size() const { int64_t data_size() const {
return array_desc_->data_size; return array_desc_->data_size;
} }
@@ -340,7 +340,7 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
} }
size_t buffer_size() const { int64_t buffer_size() const {
return allocator::allocator().size(buffer()); return allocator::allocator().size(buffer());
} }
@@ -530,7 +530,7 @@ array::array(
Shape shape, Shape shape,
Dtype dtype /* = TypeToDtype<T>() */) Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) { if (std::ssize(data) != size()) {
throw std::invalid_argument( throw std::invalid_argument(
"Data size and provided shape mismatch in array construction."); "Data size and provided shape mismatch in array construction.");
} }