mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
faster cpu indexing (#1450)
This commit is contained in:
@@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
template <typename StrideT>
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int i = dims_;
|
||||
int dims = shape_.size();
|
||||
if (dims == 0) {
|
||||
return;
|
||||
}
|
||||
int i = dims - 1;
|
||||
while (pos_[i] == (shape_[i] - 1) && i > 0) {
|
||||
pos_[i] = 0;
|
||||
loc -= (shape_[i] - 1) * strides_[i];
|
||||
@@ -98,15 +102,41 @@ struct ContiguousIterator {
|
||||
loc += strides_[i];
|
||||
}
|
||||
|
||||
void seek(StrideT n) {
|
||||
loc = 0;
|
||||
for (int i = shape_.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(n, shape_[i]);
|
||||
loc += q_and_r.rem * strides_[i];
|
||||
pos_[i] = q_and_r.rem;
|
||||
n = q_and_r.quot;
|
||||
}
|
||||
}
|
||||
|
||||
void reset() {
|
||||
loc = 0;
|
||||
std::fill(pos_.begin(), pos_.end(), 0);
|
||||
}
|
||||
|
||||
ContiguousIterator() {};
|
||||
|
||||
explicit ContiguousIterator(const array& a)
|
||||
: shape_(a.shape()), strides_(a.strides()) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
explicit ContiguousIterator(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
int dims)
|
||||
: shape_(shape.begin(), shape.begin() + dims),
|
||||
strides_(strides.begin(), strides.begin() + dims) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
dims_ = shape_.size() - 1;
|
||||
pos_ = std::vector<int>(dims_ + 1, 0);
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
StrideT loc{0};
|
||||
@@ -115,7 +145,6 @@ struct ContiguousIterator {
|
||||
std::vector<int> shape_;
|
||||
std::vector<StrideT> strides_;
|
||||
std::vector<int> pos_;
|
||||
int dims_;
|
||||
};
|
||||
|
||||
template <typename StrideT>
|
||||
|
||||
Reference in New Issue
Block a user