fix arg_reduce

This commit is contained in:
Ronan Collobert
2025-10-31 13:13:15 -07:00
parent 8d10f3ec75
commit b0d985416a

View File

@@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
Strides strides = remove_index(in.strides(), axis);
Shape shape = remove_index(in.shape(), axis);
auto in_ptr = in.data<InT>();
auto out_ptr = out.data<int64_t>();
auto out_ptr = out.data<uint32_t>();
for (int64_t i = 0; i < out.size(); ++i) {
auto loc = elem_to_loc(i, shape, strides);
auto local_in_ptr = in_ptr + loc;
int64_t ind_v = 0;
uint32_t ind_v = 0;
InT v = (*local_in_ptr);
for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
op(j, (*local_in_ptr), &ind_v, &v);