diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index b067150d8..5823e4300 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -45,7 +45,9 @@ struct ThreadSort { for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } } } } @@ -111,7 +113,9 @@ struct BlockMergeSort { bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; - idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + if (ARG_SORT) { + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + } b_idx += short(pred); a_idx += short(!pred);