mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
SmallVector: keep sizes small (int)
This commit is contained in:
@@ -121,10 +121,10 @@ class SmallVector {
|
|||||||
std::initializer_list<T> init,
|
std::initializer_list<T> init,
|
||||||
const Allocator& allocator = Allocator())
|
const Allocator& allocator = Allocator())
|
||||||
: allocator_(allocator) {
|
: allocator_(allocator) {
|
||||||
if (init.size() > capacity()) {
|
if (static_cast<int>(init.size()) > capacity()) {
|
||||||
grow(init.size());
|
grow(init.size());
|
||||||
}
|
}
|
||||||
assert(capacity() >= init.size()); // sanity check
|
assert(capacity() >= static_cast<int>(init.size())); // sanity check
|
||||||
std::uninitialized_move(init.begin(), init.end(), begin_);
|
std::uninitialized_move(init.begin(), init.end(), begin_);
|
||||||
end_ = begin_ + init.size();
|
end_ = begin_ + init.size();
|
||||||
}
|
}
|
||||||
@@ -132,7 +132,7 @@ class SmallVector {
|
|||||||
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
||||||
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
|
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
|
||||||
: allocator_(allocator) {
|
: allocator_(allocator) {
|
||||||
size_t size = std::distance(begin, end);
|
int size = std::distance(begin, end);
|
||||||
if (size > capacity()) {
|
if (size > capacity()) {
|
||||||
grow(size);
|
grow(size);
|
||||||
}
|
}
|
||||||
@@ -164,7 +164,7 @@ class SmallVector {
|
|||||||
if (this == &other) {
|
if (this == &other) {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
size_t other_size = other.size();
|
int other_size = other.size();
|
||||||
if (capacity() < other_size) {
|
if (capacity() < other_size) {
|
||||||
// Create large-enough heap-allocated storage.
|
// Create large-enough heap-allocated storage.
|
||||||
free_storage();
|
free_storage();
|
||||||
@@ -273,13 +273,13 @@ class SmallVector {
|
|||||||
return std::make_reverse_iterator(begin_);
|
return std::make_reverse_iterator(begin_);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size() const {
|
int size() const {
|
||||||
return end_ - begin_;
|
return end_ - begin_;
|
||||||
}
|
}
|
||||||
bool empty() const {
|
bool empty() const {
|
||||||
return end_ == begin_;
|
return end_ == begin_;
|
||||||
}
|
}
|
||||||
size_t capacity() const {
|
int capacity() const {
|
||||||
return end_of_storage_ - begin_;
|
return end_of_storage_ - begin_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,21 +301,21 @@ class SmallVector {
|
|||||||
return end_[-1];
|
return end_[-1];
|
||||||
}
|
}
|
||||||
|
|
||||||
T& at(size_t index) {
|
T& at(int index) {
|
||||||
if (index >= size()) {
|
if (index >= size()) {
|
||||||
throw std::out_of_range("SmallVector out of range.");
|
throw std::out_of_range("SmallVector out of range.");
|
||||||
}
|
}
|
||||||
return begin_[index];
|
return begin_[index];
|
||||||
}
|
}
|
||||||
const T& at(size_t index) const {
|
const T& at(int index) const {
|
||||||
return const_cast<SmallVector*>(this)->at(index);
|
return const_cast<SmallVector*>(this)->at(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
T& operator[](size_t index) {
|
T& operator[](int index) {
|
||||||
assert(size() > index);
|
assert(size() > index);
|
||||||
return begin_[index];
|
return begin_[index];
|
||||||
}
|
}
|
||||||
const T& operator[](size_t index) const {
|
const T& operator[](int index) const {
|
||||||
return const_cast<SmallVector*>(this)->operator[](index);
|
return const_cast<SmallVector*>(this)->operator[](index);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,7 +333,7 @@ class SmallVector {
|
|||||||
emplace_back(std::move(x));
|
emplace_back(std::move(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
void pop_back(size_t count = 1) {
|
void pop_back(int count = 1) {
|
||||||
assert(size() >= count);
|
assert(size() >= count);
|
||||||
end_ -= count;
|
end_ -= count;
|
||||||
std::destroy_n(end_, count);
|
std::destroy_n(end_, count);
|
||||||
@@ -400,7 +400,7 @@ class SmallVector {
|
|||||||
return erase(pos, pos + 1);
|
return erase(pos, pos + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void resize(size_t new_size) {
|
void resize(int new_size) {
|
||||||
if (new_size > capacity()) {
|
if (new_size > capacity()) {
|
||||||
grow(new_size);
|
grow(new_size);
|
||||||
}
|
}
|
||||||
@@ -415,7 +415,7 @@ class SmallVector {
|
|||||||
end_ = new_end;
|
end_ = new_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
void resize(size_t new_size, const T& initial_value) {
|
void resize(int new_size, const T& initial_value) {
|
||||||
if (new_size > capacity()) {
|
if (new_size > capacity()) {
|
||||||
grow(new_size);
|
grow(new_size);
|
||||||
}
|
}
|
||||||
@@ -428,7 +428,7 @@ class SmallVector {
|
|||||||
end_ = new_end;
|
end_ = new_end;
|
||||||
}
|
}
|
||||||
|
|
||||||
void reserve(size_t new_capacity) {
|
void reserve(int new_capacity) {
|
||||||
if (new_capacity > capacity()) {
|
if (new_capacity > capacity()) {
|
||||||
grow(new_capacity);
|
grow(new_capacity);
|
||||||
}
|
}
|
||||||
@@ -443,8 +443,8 @@ class SmallVector {
|
|||||||
private:
|
private:
|
||||||
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
||||||
// TODO: Move to private after removing external code using this method.
|
// TODO: Move to private after removing external code using this method.
|
||||||
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
MLX_NOINLINE void grow(int min_capacity = 0) {
|
||||||
size_t new_capacity = std::max(min_capacity, 2 * capacity());
|
int new_capacity = std::max(min_capacity, 2 * capacity());
|
||||||
// Round up to power of 2.
|
// Round up to power of 2.
|
||||||
new_capacity--;
|
new_capacity--;
|
||||||
new_capacity |= new_capacity >> 1;
|
new_capacity |= new_capacity >> 1;
|
||||||
@@ -452,9 +452,6 @@ class SmallVector {
|
|||||||
new_capacity |= new_capacity >> 4;
|
new_capacity |= new_capacity >> 4;
|
||||||
new_capacity |= new_capacity >> 8;
|
new_capacity |= new_capacity >> 8;
|
||||||
new_capacity |= new_capacity >> 16;
|
new_capacity |= new_capacity >> 16;
|
||||||
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
|
|
||||||
new_capacity |= new_capacity >> 32;
|
|
||||||
}
|
|
||||||
new_capacity++;
|
new_capacity++;
|
||||||
|
|
||||||
T* new_storage = allocator_.allocate(new_capacity);
|
T* new_storage = allocator_.allocate(new_capacity);
|
||||||
|
|||||||
Reference in New Issue
Block a user