Add remove_index utility

This commit is contained in:
Cheng 2025-05-11 11:16:48 +09:00
parent 659a51919f
commit 17bce1c092
4 changed files with 26 additions and 44 deletions

View File

@ -165,4 +165,11 @@ void shared_buffer_reshape(
const array& in, const array& in,
const Strides& out_strides, const Strides& out_strides,
array& out); array& out);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -14,10 +14,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) { void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis]; auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
Strides strides = in.strides(); Strides strides = remove_index(in.strides(), axis);
Shape shape = in.shape(); Shape shape = remove_index(in.shape(), axis);
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
auto in_ptr = in.data<InT>(); auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>(); auto out_ptr = out.data<uint32_t>();

View File

@ -257,15 +257,11 @@ void gather_axis(
const array& ind, const array& ind,
array& out, array& out,
const int axis) { const int axis) {
auto strides = ind.strides(); auto shape = remove_index(ind.shape(), axis);
strides.erase(strides.begin() + axis); ContiguousIterator ind_it(
auto shape = ind.shape(); shape, remove_index(ind.strides(), axis), src.ndim() - 1);
shape.erase(shape.begin() + axis); ContiguousIterator src_it(
ContiguousIterator ind_it(shape, strides, src.ndim() - 1); shape, remove_index(src.strides(), axis), src.ndim() - 1);
strides = src.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
auto ind_ptr = ind.data<IdxT>(); auto ind_ptr = ind.data<IdxT>();
auto src_ptr = src.data<T>(); auto src_ptr = src.data<T>();
@ -585,15 +581,11 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
template <typename T, typename IdxT, typename OpT> template <typename T, typename IdxT, typename OpT>
void scatter_axis(array& out, const array idx, const array& upd, int axis) { void scatter_axis(array& out, const array idx, const array& upd, int axis) {
auto strides = idx.strides(); auto shape = remove_index(idx.shape(), axis);
strides.erase(strides.begin() + axis); ContiguousIterator idx_it(
auto shape = idx.shape(); shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
shape.erase(shape.begin() + axis); ContiguousIterator upd_it(
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1); shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
strides = upd.strides();
strides.erase(strides.begin() + axis);
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
auto idx_ptr = idx.data<IdxT>(); auto idx_ptr = idx.data<IdxT>();
auto upd_ptr = upd.data<T>(); auto upd_ptr = upd.data<T>();

View File

@ -2,6 +2,7 @@
#include <fmt/format.h> #include <fmt/format.h>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/includes.h"
@ -458,17 +459,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Set source info // Set source info
auto shape = idx.shape(); compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
shape.erase(shape.begin() + axis_); compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
compute_encoder.set_vector_bytes(shape, 3); compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
auto strides = src.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(src.shape(axis_), 8); compute_encoder.set_bytes(src.shape(axis_), 8);
@ -582,17 +575,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Set source info // Set source info
auto shape = idx.shape(); compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
shape.erase(shape.begin() + axis_); compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
compute_encoder.set_vector_bytes(shape, 3); compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
auto strides = upd.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 4);
strides = idx.strides();
strides.erase(strides.begin() + axis_);
compute_encoder.set_vector_bytes(strides, 5);
compute_encoder.set_bytes(ndim - 1, 6); compute_encoder.set_bytes(ndim - 1, 6);
compute_encoder.set_bytes(axis_, 7); compute_encoder.set_bytes(axis_, 7);
compute_encoder.set_bytes(out.shape(axis_), 8); compute_encoder.set_bytes(out.shape(axis_), 8);