From 73f22d622633acf3dcd2a2e043fff1947916db13 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 17 Aug 2025 08:42:20 -0700 Subject: [PATCH] Ensure small sort doesn't use indices if not argsort (#2506) --- mlx/backend/metal/kernels/sort.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index b067150d8..5823e4300 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -45,7 +45,9 @@ struct ThreadSort { for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } } } } @@ -111,7 +113,9 @@ struct BlockMergeSort { bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; - idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + if (ARG_SORT) { + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + } b_idx += short(pred); a_idx += short(!pred);