SmallVector: keep sizes small (int)

This commit is contained in:
Ronan Collobert
2025-10-29 16:06:10 -07:00
parent 66fcb9fe94
commit 26f7155537

View File

@@ -121,10 +121,10 @@ class SmallVector {
std::initializer_list<T> init, std::initializer_list<T> init,
const Allocator& allocator = Allocator()) const Allocator& allocator = Allocator())
: allocator_(allocator) { : allocator_(allocator) {
if (init.size() > capacity()) { if (static_cast<int>(init.size()) > capacity()) {
grow(init.size()); grow(init.size());
} }
assert(capacity() >= init.size()); // sanity check assert(capacity() >= static_cast<int>(init.size())); // sanity check
std::uninitialized_move(init.begin(), init.end(), begin_); std::uninitialized_move(init.begin(), init.end(), begin_);
end_ = begin_ + init.size(); end_ = begin_ + init.size();
} }
@@ -132,7 +132,7 @@ class SmallVector {
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>> template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator()) SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
: allocator_(allocator) { : allocator_(allocator) {
size_t size = std::distance(begin, end); int size = std::distance(begin, end);
if (size > capacity()) { if (size > capacity()) {
grow(size); grow(size);
} }
@@ -164,7 +164,7 @@ class SmallVector {
if (this == &other) { if (this == &other) {
return *this; return *this;
} }
size_t other_size = other.size(); int other_size = other.size();
if (capacity() < other_size) { if (capacity() < other_size) {
// Create large-enough heap-allocated storage. // Create large-enough heap-allocated storage.
free_storage(); free_storage();
@@ -273,13 +273,13 @@ class SmallVector {
return std::make_reverse_iterator(begin_); return std::make_reverse_iterator(begin_);
} }
size_t size() const { int size() const {
return end_ - begin_; return end_ - begin_;
} }
bool empty() const { bool empty() const {
return end_ == begin_; return end_ == begin_;
} }
size_t capacity() const { int capacity() const {
return end_of_storage_ - begin_; return end_of_storage_ - begin_;
} }
@@ -301,21 +301,21 @@ class SmallVector {
return end_[-1]; return end_[-1];
} }
T& at(size_t index) { T& at(int index) {
if (index >= size()) { if (index >= size()) {
throw std::out_of_range("SmallVector out of range."); throw std::out_of_range("SmallVector out of range.");
} }
return begin_[index]; return begin_[index];
} }
const T& at(size_t index) const { const T& at(int index) const {
return const_cast<SmallVector*>(this)->at(index); return const_cast<SmallVector*>(this)->at(index);
} }
T& operator[](size_t index) { T& operator[](int index) {
assert(size() > index); assert(size() > index);
return begin_[index]; return begin_[index];
} }
const T& operator[](size_t index) const { const T& operator[](int index) const {
return const_cast<SmallVector*>(this)->operator[](index); return const_cast<SmallVector*>(this)->operator[](index);
} }
@@ -333,7 +333,7 @@ class SmallVector {
emplace_back(std::move(x)); emplace_back(std::move(x));
} }
void pop_back(size_t count = 1) { void pop_back(int count = 1) {
assert(size() >= count); assert(size() >= count);
end_ -= count; end_ -= count;
std::destroy_n(end_, count); std::destroy_n(end_, count);
@@ -400,7 +400,7 @@ class SmallVector {
return erase(pos, pos + 1); return erase(pos, pos + 1);
} }
void resize(size_t new_size) { void resize(int new_size) {
if (new_size > capacity()) { if (new_size > capacity()) {
grow(new_size); grow(new_size);
} }
@@ -415,7 +415,7 @@ class SmallVector {
end_ = new_end; end_ = new_end;
} }
void resize(size_t new_size, const T& initial_value) { void resize(int new_size, const T& initial_value) {
if (new_size > capacity()) { if (new_size > capacity()) {
grow(new_size); grow(new_size);
} }
@@ -428,7 +428,7 @@ class SmallVector {
end_ = new_end; end_ = new_end;
} }
void reserve(size_t new_capacity) { void reserve(int new_capacity) {
if (new_capacity > capacity()) { if (new_capacity > capacity()) {
grow(new_capacity); grow(new_capacity);
} }
@@ -443,8 +443,8 @@ class SmallVector {
private: private:
// Grows the backing store by a factor of two, and at least to {min_capacity}. // Grows the backing store by a factor of two, and at least to {min_capacity}.
// TODO: Move to private after removing external code using this method. // TODO: Move to private after removing external code using this method.
MLX_NOINLINE void grow(size_t min_capacity = 0) { MLX_NOINLINE void grow(int min_capacity = 0) {
size_t new_capacity = std::max(min_capacity, 2 * capacity()); int new_capacity = std::max(min_capacity, 2 * capacity());
// Round up to power of 2. // Round up to power of 2.
new_capacity--; new_capacity--;
new_capacity |= new_capacity >> 1; new_capacity |= new_capacity >> 1;
@@ -452,9 +452,6 @@ class SmallVector {
new_capacity |= new_capacity >> 4; new_capacity |= new_capacity >> 4;
new_capacity |= new_capacity >> 8; new_capacity |= new_capacity >> 8;
new_capacity |= new_capacity >> 16; new_capacity |= new_capacity >> 16;
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
new_capacity |= new_capacity >> 32;
}
new_capacity++; new_capacity++;
T* new_storage = allocator_.allocate(new_capacity); T* new_storage = allocator_.allocate(new_capacity);