mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	faster cpu indexing (#1450)
This commit is contained in:
		| @@ -1,5 +1,4 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #include <algorithm> | ||||
| #include <cassert> | ||||
| #include <cmath> | ||||
| @@ -81,11 +80,18 @@ void gather( | ||||
|   T* dst_ptr = out.data<T>(); | ||||
|   size_t out_idx = 0; | ||||
|  | ||||
|   std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end()); | ||||
|   ContiguousIterator<size_t> src_it; | ||||
|   if (!can_copy && src.ndim() > 0) { | ||||
|     src_it = std::move( | ||||
|         ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim())); | ||||
|   } | ||||
|   for (int idx = 0; idx < ind_size; idx++) { | ||||
|     size_t src_idx = 0; | ||||
|     for (int ii = 0; ii < inds.size(); ++ii) { | ||||
|       auto ax = axes[ii]; | ||||
|       auto idx_loc = elem_to_loc(idx, inds[ii]); | ||||
|       auto idx_loc = its[ii].loc; | ||||
|       its[ii].step(); | ||||
|       auto idx_val = | ||||
|           offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax)); | ||||
|       src_idx += (idx_val * src.strides()[ax]); | ||||
| @@ -99,9 +105,10 @@ void gather( | ||||
|       out_idx += slice_size; | ||||
|     } else { | ||||
|       for (int jj = 0; jj < slice_size; jj++) { | ||||
|         auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); | ||||
|         dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; | ||||
|         dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc]; | ||||
|         src_it.step(); | ||||
|       } | ||||
|       src_it.reset(); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| @@ -223,21 +230,29 @@ void scatter( | ||||
|     update_size *= us; | ||||
|   } | ||||
|  | ||||
|   std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end()); | ||||
|   ContiguousIterator<size_t> update_it(updates); | ||||
|   ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim()); | ||||
|  | ||||
|   for (int i = 0; i < n_updates; ++i) { | ||||
|     size_t out_offset = 0; | ||||
|     for (int j = 0; j < nind; ++j) { | ||||
|       auto ax = axes[j]; | ||||
|       auto idx_loc = elem_to_loc(i, inds[j]); | ||||
|       auto idx_loc = its[j].loc; | ||||
|       its[j].step(); | ||||
|       auto idx_val = | ||||
|           offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax)); | ||||
|       out_offset += (idx_val * out.strides()[ax]); | ||||
|     } | ||||
|     update_it.seek(i * update_size); | ||||
|     for (int j = 0; j < update_size; ++j) { | ||||
|       auto update_loc = elem_to_loc(i * update_size + j, updates); | ||||
|       auto out_loc = elem_to_loc(j, update_shape, out.strides()); | ||||
|       op(updates.data<InT>()[update_loc], | ||||
|          out.data<InT>() + out_offset + out_loc); | ||||
|       op(updates.data<InT>()[update_it.loc], | ||||
|          out.data<InT>() + out_offset + out_it.loc); | ||||
|       update_it.step(); | ||||
|       out_it.step(); | ||||
|     } | ||||
|     out_it.reset(); | ||||
|     update_it.reset(); | ||||
|   } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -88,7 +88,11 @@ std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims( | ||||
| template <typename StrideT> | ||||
| struct ContiguousIterator { | ||||
|   inline void step() { | ||||
|     int i = dims_; | ||||
|     int dims = shape_.size(); | ||||
|     if (dims == 0) { | ||||
|       return; | ||||
|     } | ||||
|     int i = dims - 1; | ||||
|     while (pos_[i] == (shape_[i] - 1) && i > 0) { | ||||
|       pos_[i] = 0; | ||||
|       loc -= (shape_[i] - 1) * strides_[i]; | ||||
| @@ -98,15 +102,41 @@ struct ContiguousIterator { | ||||
|     loc += strides_[i]; | ||||
|   } | ||||
|  | ||||
|   void seek(StrideT n) { | ||||
|     loc = 0; | ||||
|     for (int i = shape_.size() - 1; i >= 0; --i) { | ||||
|       auto q_and_r = ldiv(n, shape_[i]); | ||||
|       loc += q_and_r.rem * strides_[i]; | ||||
|       pos_[i] = q_and_r.rem; | ||||
|       n = q_and_r.quot; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void reset() { | ||||
|     loc = 0; | ||||
|     std::fill(pos_.begin(), pos_.end(), 0); | ||||
|   } | ||||
|  | ||||
|   ContiguousIterator() {}; | ||||
|  | ||||
|   explicit ContiguousIterator(const array& a) | ||||
|       : shape_(a.shape()), strides_(a.strides()) { | ||||
|     if (!shape_.empty()) { | ||||
|       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); | ||||
|       pos_ = std::vector<int>(shape_.size(), 0); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   explicit ContiguousIterator( | ||||
|       const std::vector<int>& shape, | ||||
|       const std::vector<StrideT>& strides, | ||||
|       int dims) | ||||
|       : shape_(shape.begin(), shape.begin() + dims), | ||||
|         strides_(strides.begin(), strides.begin() + dims) { | ||||
|     std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); | ||||
|     dims_ = shape_.size() - 1; | ||||
|     pos_ = std::vector<int>(dims_ + 1, 0); | ||||
|     if (!shape_.empty()) { | ||||
|       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); | ||||
|       pos_ = std::vector<int>(shape_.size(), 0); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   StrideT loc{0}; | ||||
| @@ -115,7 +145,6 @@ struct ContiguousIterator { | ||||
|   std::vector<int> shape_; | ||||
|   std::vector<StrideT> strides_; | ||||
|   std::vector<int> pos_; | ||||
|   int dims_; | ||||
| }; | ||||
|  | ||||
| template <typename StrideT> | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun