faster cpu indexing (#1450)

This commit is contained in:
Awni Hannun 2024-10-03 13:53:47 -07:00 committed by GitHub
parent d878015228
commit 5523d9c426
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 14 deletions

View File

@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@ -81,11 +80,18 @@ void gather(
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0; size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
}
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
for (int ii = 0; ii < inds.size(); ++ii) { for (int ii = 0; ii < inds.size(); ++ii) {
auto ax = axes[ii]; auto ax = axes[ii];
auto idx_loc = elem_to_loc(idx, inds[ii]); auto idx_loc = its[ii].loc;
its[ii].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax)); offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
src_idx += (idx_val * src.strides()[ax]); src_idx += (idx_val * src.strides()[ax]);
@ -99,9 +105,10 @@ void gather(
out_idx += slice_size; out_idx += slice_size;
} else { } else {
for (int jj = 0; jj < slice_size; jj++) { for (int jj = 0; jj < slice_size; jj++) {
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; src_it.step();
} }
src_it.reset();
} }
} }
} }
@ -223,21 +230,29 @@ void scatter(
update_size *= us; update_size *= us;
} }
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;
for (int j = 0; j < nind; ++j) { for (int j = 0; j < nind; ++j) {
auto ax = axes[j]; auto ax = axes[j];
auto idx_loc = elem_to_loc(i, inds[j]); auto idx_loc = its[j].loc;
its[j].step();
auto idx_val = auto idx_val =
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax)); offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
out_offset += (idx_val * out.strides()[ax]); out_offset += (idx_val * out.strides()[ax]);
} }
update_it.seek(i * update_size);
for (int j = 0; j < update_size; ++j) { for (int j = 0; j < update_size; ++j) {
auto update_loc = elem_to_loc(i * update_size + j, updates); op(updates.data<InT>()[update_it.loc],
auto out_loc = elem_to_loc(j, update_shape, out.strides()); out.data<InT>() + out_offset + out_it.loc);
op(updates.data<InT>()[update_loc], update_it.step();
out.data<InT>() + out_offset + out_loc); out_it.step();
} }
out_it.reset();
update_it.reset();
} }
} }

View File

@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
template <typename StrideT> template <typename StrideT>
struct ContiguousIterator { struct ContiguousIterator {
inline void step() { inline void step() {
int i = dims_; int dims = shape_.size();
if (dims == 0) {
return;
}
int i = dims - 1;
while (pos_[i] == (shape_[i] - 1) && i > 0) { while (pos_[i] == (shape_[i] - 1) && i > 0) {
pos_[i] = 0; pos_[i] = 0;
loc -= (shape_[i] - 1) * strides_[i]; loc -= (shape_[i] - 1) * strides_[i];
@ -98,15 +102,41 @@ struct ContiguousIterator {
loc += strides_[i]; loc += strides_[i];
} }
// Position the iterator at linear element index `n`.
//
// Walks the dimensions in shape_ from last to first, peeling one coordinate
// off `n` per dimension. Each coordinate is stored in pos_ and its
// contribution (coordinate * stride) is accumulated into loc. Whatever
// quotient remains after the first dimension has been processed is discarded.
void seek(StrideT n) {
  loc = 0;
  for (int i = static_cast<int>(shape_.size()) - 1; i >= 0; --i) {
    // Divide in StrideT instead of going through ldiv(): ldiv converts both
    // operands to long, which is only 32 bits wide on LLP64 targets and
    // would silently truncate a large 64-bit index.
    const StrideT extent = static_cast<StrideT>(shape_[i]);
    const StrideT coord = n % extent;
    loc += coord * strides_[i];
    pos_[i] = static_cast<int>(coord);
    n /= extent;
  }
}
// Return the iterator to its starting state: every per-dimension position
// counter goes back to 0 and the accumulated offset is cleared.
void reset() {
  pos_.assign(pos_.size(), 0);
  loc = 0;
}
ContiguousIterator() {};
explicit ContiguousIterator(const array& a)
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
}
}
explicit ContiguousIterator( explicit ContiguousIterator(
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<StrideT>& strides, const std::vector<StrideT>& strides,
int dims) int dims)
: shape_(shape.begin(), shape.begin() + dims), : shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) { strides_(strides.begin(), strides.begin() + dims) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); if (!shape_.empty()) {
dims_ = shape_.size() - 1; std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(dims_ + 1, 0); pos_ = std::vector<int>(shape_.size(), 0);
}
} }
StrideT loc{0}; StrideT loc{0};
@ -115,7 +145,6 @@ struct ContiguousIterator {
std::vector<int> shape_; std::vector<int> shape_;
std::vector<StrideT> strides_; std::vector<StrideT> strides_;
std::vector<int> pos_; std::vector<int> pos_;
int dims_;
}; };
template <typename StrideT> template <typename StrideT>