From 5523d9c426cf62a4dc2a4d16e50a741fa39fdb59 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 3 Oct 2024 13:53:47 -0700 Subject: [PATCH] faster cpu indexing (#1450) --- mlx/backend/common/indexing.cpp | 33 ++++++++++++++++++++-------- mlx/backend/common/utils.h | 39 ++++++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp index 64fc21ec8..1bb3eb44f 100644 --- a/mlx/backend/common/indexing.cpp +++ b/mlx/backend/common/indexing.cpp @@ -1,5 +1,4 @@ // Copyright © 2023 Apple Inc. - #include #include #include @@ -81,11 +80,18 @@ void gather( T* dst_ptr = out.data(); size_t out_idx = 0; + std::vector> its(inds.begin(), inds.end()); + ContiguousIterator src_it; + if (!can_copy && src.ndim() > 0) { + src_it = std::move( + ContiguousIterator(slice_sizes, src.strides(), src.ndim())); + } for (int idx = 0; idx < ind_size; idx++) { size_t src_idx = 0; for (int ii = 0; ii < inds.size(); ++ii) { auto ax = axes[ii]; - auto idx_loc = elem_to_loc(idx, inds[ii]); + auto idx_loc = its[ii].loc; + its[ii].step(); auto idx_val = offset_neg_idx(inds[ii].data()[idx_loc], src.shape(ax)); src_idx += (idx_val * src.strides()[ax]); @@ -99,9 +105,10 @@ void gather( out_idx += slice_size; } else { for (int jj = 0; jj < slice_size; jj++) { - auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); - dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; + dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; + src_it.step(); } + src_it.reset(); } } } @@ -223,21 +230,29 @@ void scatter( update_size *= us; } + std::vector> its(inds.begin(), inds.end()); + ContiguousIterator update_it(updates); + ContiguousIterator out_it(update_shape, out.strides(), out.ndim()); + for (int i = 0; i < n_updates; ++i) { size_t out_offset = 0; for (int j = 0; j < nind; ++j) { auto ax = axes[j]; - auto idx_loc = elem_to_loc(i, inds[j]); + auto idx_loc = its[j].loc; + its[j].step(); auto idx_val = offset_neg_idx(inds[j].data()[idx_loc], out.shape(ax)); out_offset += (idx_val * out.strides()[ax]); } + update_it.seek(i * update_size); for (int j = 0; j < update_size; ++j) { - auto update_loc = elem_to_loc(i * update_size + j, updates); - auto out_loc = elem_to_loc(j, update_shape, out.strides()); - op(updates.data()[update_loc], - out.data() + out_offset + out_loc); + op(updates.data()[update_it.loc], + out.data() + out_offset + out_it.loc); + update_it.step(); + out_it.step(); } + out_it.reset(); + update_it.reset(); } } diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index b037c309f..7ce38b908 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -88,7 +88,11 @@ std::pair, std::vector> collapse_contiguous_dims( template struct ContiguousIterator { inline void step() { - int i = dims_; + int dims = shape_.size(); + if (dims == 0) { + return; + } + int i = dims - 1; while (pos_[i] == (shape_[i] - 1) && i > 0) { pos_[i] = 0; loc -= (shape_[i] - 1) * strides_[i]; @@ -98,15 +102,41 @@ struct ContiguousIterator { loc += strides_[i]; } + void seek(StrideT n) { + loc = 0; + for (int i = shape_.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(n, shape_[i]); + loc += q_and_r.rem * strides_[i]; + pos_[i] = q_and_r.rem; + n = q_and_r.quot; + } + } + + void reset() { + loc = 0; + std::fill(pos_.begin(), pos_.end(), 0); + } + + ContiguousIterator() {}; + + explicit ContiguousIterator(const array& a) + : shape_(a.shape()), strides_(a.strides()) { + if (!shape_.empty()) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = std::vector(shape_.size(), 0); + } + } + explicit ContiguousIterator( const std::vector& shape, const std::vector& strides, int dims) : shape_(shape.begin(), shape.begin() + dims), strides_(strides.begin(), strides.begin() + dims) { - std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); - dims_ = shape_.size() - 1; - pos_ = std::vector(dims_ + 1, 0); + if (!shape_.empty()) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = std::vector(shape_.size(), 0); + } } StrideT loc{0}; @@ -115,7 +145,6 @@ struct ContiguousIterator { std::vector shape_; std::vector strides_; std::vector pos_; - int dims_; }; template