mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
faster cpu indexing (#1450)
This commit is contained in:
parent
d878015228
commit
5523d9c426
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@ -81,11 +80,18 @@ void gather(
|
||||
T* dst_ptr = out.data<T>();
|
||||
size_t out_idx = 0;
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> src_it;
|
||||
if (!can_copy && src.ndim() > 0) {
|
||||
src_it = std::move(
|
||||
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
|
||||
}
|
||||
for (int idx = 0; idx < ind_size; idx++) {
|
||||
size_t src_idx = 0;
|
||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||
auto ax = axes[ii];
|
||||
auto idx_loc = elem_to_loc(idx, inds[ii]);
|
||||
auto idx_loc = its[ii].loc;
|
||||
its[ii].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
|
||||
src_idx += (idx_val * src.strides()[ax]);
|
||||
@ -99,9 +105,10 @@ void gather(
|
||||
out_idx += slice_size;
|
||||
} else {
|
||||
for (int jj = 0; jj < slice_size; jj++) {
|
||||
auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
|
||||
dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
|
||||
src_it.step();
|
||||
}
|
||||
src_it.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -223,21 +230,29 @@ void scatter(
|
||||
update_size *= us;
|
||||
}
|
||||
|
||||
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
|
||||
ContiguousIterator<size_t> update_it(updates);
|
||||
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
|
||||
|
||||
for (int i = 0; i < n_updates; ++i) {
|
||||
size_t out_offset = 0;
|
||||
for (int j = 0; j < nind; ++j) {
|
||||
auto ax = axes[j];
|
||||
auto idx_loc = elem_to_loc(i, inds[j]);
|
||||
auto idx_loc = its[j].loc;
|
||||
its[j].step();
|
||||
auto idx_val =
|
||||
offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
|
||||
out_offset += (idx_val * out.strides()[ax]);
|
||||
}
|
||||
update_it.seek(i * update_size);
|
||||
for (int j = 0; j < update_size; ++j) {
|
||||
auto update_loc = elem_to_loc(i * update_size + j, updates);
|
||||
auto out_loc = elem_to_loc(j, update_shape, out.strides());
|
||||
op(updates.data<InT>()[update_loc],
|
||||
out.data<InT>() + out_offset + out_loc);
|
||||
op(updates.data<InT>()[update_it.loc],
|
||||
out.data<InT>() + out_offset + out_it.loc);
|
||||
update_it.step();
|
||||
out_it.step();
|
||||
}
|
||||
out_it.reset();
|
||||
update_it.reset();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
||||
template <typename StrideT>
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int i = dims_;
|
||||
int dims = shape_.size();
|
||||
if (dims == 0) {
|
||||
return;
|
||||
}
|
||||
int i = dims - 1;
|
||||
while (pos_[i] == (shape_[i] - 1) && i > 0) {
|
||||
pos_[i] = 0;
|
||||
loc -= (shape_[i] - 1) * strides_[i];
|
||||
@ -98,15 +102,41 @@ struct ContiguousIterator {
|
||||
loc += strides_[i];
|
||||
}
|
||||
|
||||
void seek(StrideT n) {
|
||||
loc = 0;
|
||||
for (int i = shape_.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(n, shape_[i]);
|
||||
loc += q_and_r.rem * strides_[i];
|
||||
pos_[i] = q_and_r.rem;
|
||||
n = q_and_r.quot;
|
||||
}
|
||||
}
|
||||
|
||||
void reset() {
|
||||
loc = 0;
|
||||
std::fill(pos_.begin(), pos_.end(), 0);
|
||||
}
|
||||
|
||||
ContiguousIterator() {};
|
||||
|
||||
explicit ContiguousIterator(const array& a)
|
||||
: shape_(a.shape()), strides_(a.strides()) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
explicit ContiguousIterator(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<StrideT>& strides,
|
||||
int dims)
|
||||
: shape_(shape.begin(), shape.begin() + dims),
|
||||
strides_(strides.begin(), strides.begin() + dims) {
|
||||
if (!shape_.empty()) {
|
||||
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
|
||||
dims_ = shape_.size() - 1;
|
||||
pos_ = std::vector<int>(dims_ + 1, 0);
|
||||
pos_ = std::vector<int>(shape_.size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
StrideT loc{0};
|
||||
@ -115,7 +145,6 @@ struct ContiguousIterator {
|
||||
std::vector<int> shape_;
|
||||
std::vector<StrideT> strides_;
|
||||
std::vector<int> pos_;
|
||||
int dims_;
|
||||
};
|
||||
|
||||
template <typename StrideT>
|
||||
|
Loading…
Reference in New Issue
Block a user