array: use int or int64_t instead of size_t

This commit is contained in:
Ronan Collobert
2025-10-29 16:02:31 -07:00
parent d1e06117e8
commit 66fcb9fe94
2 changed files with 15 additions and 14 deletions

View File

@@ -44,11 +44,11 @@ std::vector<array> array::make_arrays(
const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) {
std::vector<array> outputs;
for (size_t i = 0; i < shapes.size(); ++i) {
for (int i = 0; i < std::ssize(shapes); ++i) {
outputs.emplace_back(std::move(shapes[i]), dtypes[i], primitive, inputs);
}
// For each node in |outputs|, its siblings are the other nodes.
for (size_t i = 0; i < outputs.size(); ++i) {
for (int i = 0; i < std::ssize(outputs); ++i) {
auto siblings = outputs;
siblings.erase(siblings.begin() + i);
outputs[i].set_siblings(std::move(siblings), i);
@@ -145,8 +145,9 @@ void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
auto max_dim =
static_cast<int64_t>(*std::max_element(shape().begin(), shape().end()));
array_desc_->flags.col_contiguous = size() <= 1 || size() == max_dim;
}
void array::set_data(
@@ -192,7 +193,7 @@ array::~array() {
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
if (auto n = std::ssize(siblings()); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
@@ -274,7 +275,7 @@ array::ArrayDesc::~ArrayDesc() {
ad.inputs.clear();
for (auto& [_, a] : input_map) {
bool is_deletable =
(a.array_desc_.use_count() <= a.siblings().size() + 1);
(a.array_desc_.use_count() <= std::ssize(a.siblings()) + 1);
// An array with siblings is deletable only if all of its siblings
// are deletable
for (auto& s : a.siblings()) {
@@ -283,7 +284,7 @@ array::ArrayDesc::~ArrayDesc() {
}
int is_input = (input_map.find(s.id()) != input_map.end());
is_deletable &=
s.array_desc_.use_count() <= a.siblings().size() + is_input;
s.array_desc_.use_count() <= std::ssize(a.siblings()) + is_input;
}
if (is_deletable) {
for_deletion.push_back(std::move(a.array_desc_));

View File

@@ -81,22 +81,22 @@ class array {
}
/** The size of the array's datatype in bytes. */
size_t itemsize() const {
int itemsize() const {
return size_of(dtype());
}
/** The number of elements in the array. */
size_t size() const {
int64_t size() const {
return array_desc_->size;
}
/** The number of bytes in the array. */
size_t nbytes() const {
int64_t nbytes() const {
return size() * itemsize();
}
/** The number of dimensions of the array. */
size_t ndim() const {
int ndim() const {
return array_desc_->shape.size();
}
@@ -329,7 +329,7 @@ class array {
* corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
* Note, ``data_size`` is in units of ``item_size`` (not bytes).
**/
size_t data_size() const {
int64_t data_size() const {
return array_desc_->data_size;
}
@@ -340,7 +340,7 @@ class array {
return array_desc_->data->buffer;
}
size_t buffer_size() const {
int64_t buffer_size() const {
return allocator::allocator().size(buffer());
}
@@ -530,7 +530,7 @@ array::array(
Shape shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {
if (std::ssize(data) != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");
}