From b0d985416aed3de70bc9fa63eaf08b6f9ee029eb Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 31 Oct 2025 13:13:15 -0700 Subject: [PATCH] fix arg_reduce --- mlx/backend/cpu/arg_reduce.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 41ab9fb60..3f42e7183 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) { Strides strides = remove_index(in.strides(), axis); Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); - auto out_ptr = out.data(); + auto out_ptr = out.data(); for (int64_t i = 0; i < out.size(); ++i) { auto loc = elem_to_loc(i, shape, strides); auto local_in_ptr = in_ptr + loc; - int64_t ind_v = 0; + uint32_t ind_v = 0; InT v = (*local_in_ptr); for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { op(j, (*local_in_ptr), &ind_v, &v);