diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp index 41ab9fb60..3f42e7183 100644 --- a/mlx/backend/cpu/arg_reduce.cpp +++ b/mlx/backend/cpu/arg_reduce.cpp @@ -17,12 +17,12 @@ void arg_reduce(const array& in, array& out, const OpT& op, int axis) { Strides strides = remove_index(in.strides(), axis); Shape shape = remove_index(in.shape(), axis); auto in_ptr = in.data(); - auto out_ptr = out.data(); + auto out_ptr = out.data(); for (int64_t i = 0; i < out.size(); ++i) { auto loc = elem_to_loc(i, shape, strides); auto local_in_ptr = in_ptr + loc; - int64_t ind_v = 0; + uint32_t ind_v = 0; InT v = (*local_in_ptr); for (int64_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) { op(j, (*local_in_ptr), &ind_v, &v);