From 17bce1c09279f4eb43c6f895a0cedf33d005e2bc Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 11 May 2025 11:16:48 +0900 Subject: [PATCH] Add remove_index utility --- mlx/backend/common/utils.h | 7 +++++++ mlx/backend/cpu/arg_reduce.cpp | 6 ++---- mlx/backend/cpu/indexing.cpp | 28 ++++++++++------------------ mlx/backend/metal/indexing.cpp | 29 +++++++---------------------- 4 files changed, 26 insertions(+), 44 deletions(-) diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 20a65d7b1..a4bdaa5ca 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -165,4 +165,11 @@ void shared_buffer_reshape( const array& in, const Strides& out_strides, array& out); + +template +inline std::vector remove_index(std::vector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + } // namespace mlx::core diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index a8ba3efe2..66468912d 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -14,10 +14,8 @@ template void arg_reduce(const array& in, array& out, const OpT& op, int axis) { auto axis_size = in.shape()[axis]; auto axis_stride = in.strides()[axis]; - Strides strides = in.strides(); - Shape shape = in.shape(); - strides.erase(strides.begin() + axis); - shape.erase(shape.begin() + axis); + Strides strides = remove_index(in.strides(), axis); + Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); auto out_ptr = out.data(); diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index 70d6b3eb7..5f99093e5 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -257,15 +257,11 @@ void gather_axis( const array& ind, array& out, const int axis) { - auto strides = ind.strides(); - strides.erase(strides.begin() + axis); - auto shape = ind.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator ind_it(shape, strides, src.ndim() - 1); - - strides = src.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator src_it(shape, strides, src.ndim() - 1); + auto shape = remove_index(ind.shape(), axis); + ContiguousIterator ind_it( + shape, remove_index(ind.strides(), axis), src.ndim() - 1); + ContiguousIterator src_it( + shape, remove_index(src.strides(), axis), src.ndim() - 1); auto ind_ptr = ind.data(); auto src_ptr = src.data(); @@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector& inputs, array& out) { template void scatter_axis(array& out, const array idx, const array& upd, int axis) { - auto strides = idx.strides(); - strides.erase(strides.begin() + axis); - auto shape = idx.shape(); - shape.erase(shape.begin() + axis); - ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); - - strides = upd.strides(); - strides.erase(strides.begin() + axis); - ContiguousIterator upd_it(shape, strides, upd.ndim() - 1); + auto shape = remove_index(idx.shape(), axis); + ContiguousIterator idx_it( + shape, remove_index(idx.strides(), axis), upd.ndim() - 1); + ContiguousIterator upd_it( + shape, remove_index(upd.strides(), axis), upd.ndim() - 1); auto idx_ptr = idx.data(); auto upd_ptr = upd.data(); diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cccfd908a..d2a601b1e 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -2,6 +2,7 @@ #include #include "mlx/backend/common/compiled.h" +#include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/jit/includes.h" @@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = src.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(src.shape(axis_), 8); @@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); // Set source info - auto shape = idx.shape(); - shape.erase(shape.begin() + axis_); - compute_encoder.set_vector_bytes(shape, 3); - - auto strides = upd.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 4); - - strides = idx.strides(); - strides.erase(strides.begin() + axis_); - compute_encoder.set_vector_bytes(strides, 5); + compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3); + compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4); + compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5); compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(out.shape(axis_), 8);