faster cpu indexing (#1450)

This commit is contained in:
Awni Hannun
2024-10-03 13:53:47 -07:00
committed by GitHub
parent d878015228
commit 5523d9c426
2 changed files with 58 additions and 14 deletions

View File

@@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
template <typename StrideT>
struct ContiguousIterator {
// Advance the iterator by one element, updating `loc` and the per-dimension
// indices in `pos_`.
// NOTE(review): this listing is a rendered diff; it appears to contain both
// the pre-change and post-change declaration of `i`, and some lines are
// elided at the hunk boundary further down — confirm against the real file.
inline void step() {
// NOTE(review): `i` is declared again a few lines below; this `dims_`-based
// initialization looks like the line the commit removed — confirm.
int i = dims_;
int dims = shape_.size();
// Nothing to advance when there are no dimensions.
if (dims == 0) {
return;
}
// Start from the last dimension.
int i = dims - 1;
// While dimension i sits at its last index (and is not dimension 0), wrap it
// back to index 0 and undo the offset it had accumulated.
while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i];
// NOTE(review): lines are elided at the hunk boundary below; how `i` is
// updated and where the loop closes is not visible here.
@@ -98,15 +102,41 @@ struct ContiguousIterator {
// Move one step along dimension i.
loc += strides_[i];
}
// Position the iterator at linear element index `n`: decomposes `n` into one
// index per dimension (last dimension varying fastest), stores those indices
// in `pos_`, and sets `loc` to the matching strided offset.
void seek(StrideT n) {
  loc = 0;
  // Peel off one mixed-radix digit per dimension, last dimension first.
  for (int i = static_cast<int>(shape_.size()) - 1; i >= 0; --i) {
    // Divide in StrideT instead of calling ldiv(): ldiv takes `long`, which is
    // only 32 bits wide on LLP64 targets and would truncate a 64-bit `n`.
    const StrideT extent = static_cast<StrideT>(shape_[i]);
    const StrideT index = n % extent;
    loc += index * strides_[i];
    pos_[i] = static_cast<int>(index);
    n /= extent;
  }
}
// Rewind to the first element: zero every per-dimension index and the offset.
void reset() {
  for (auto& index : pos_) {
    index = 0;
  }
  loc = 0;
}
// Default constructor: leaves the shape, stride and position vectors empty.
ContiguousIterator() {}
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
// Build an iterator over only the first `dims` entries of `shape` and
// `strides`.
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
// NOTE(review): the next three statements run unconditionally and repeat the
// work of the guarded block below. In this rendered diff they look like the
// pre-change lines the commit removed — confirm against the real file.
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
// Index of the last dimension after collapsing.
dims_ = shape_.size() - 1;
pos_ = std::vector<int>(dims_ + 1, 0);
// Only collapse dimensions and allocate the position vector when at least one
// dimension is present.
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
StrideT loc{0};
@@ -115,7 +145,6 @@ struct ContiguousIterator {
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
int dims_;
};
template <typename StrideT>