Add remove_index utility (#2173)

This commit is contained in:
Cheng
2025-05-14 09:09:56 +09:00
committed by GitHub
parent 3aa9cf3f9e
commit eca2f3eb97
4 changed files with 26 additions and 44 deletions

View File

@@ -14,10 +14,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<uint32_t>();