mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
Ensure small sort doesn't use indices if not argsort (#2506)
This commit is contained in:
parent
c422050ca7
commit
73f22d6226
@ -45,11 +45,13 @@ struct ThreadSort {
|
|||||||
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
|
||||||
if (op(vals[j + 1], vals[j])) {
|
if (op(vals[j + 1], vals[j])) {
|
||||||
thread_swap(vals[j + 1], vals[j]);
|
thread_swap(vals[j + 1], vals[j]);
|
||||||
|
if (ARG_SORT) {
|
||||||
thread_swap(idxs[j + 1], idxs[j]);
|
thread_swap(idxs[j + 1], idxs[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -111,7 +113,9 @@ struct BlockMergeSort {
|
|||||||
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
|
||||||
|
|
||||||
vals[i] = pred ? b : a;
|
vals[i] = pred ? b : a;
|
||||||
|
if (ARG_SORT) {
|
||||||
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
|
||||||
|
}
|
||||||
|
|
||||||
b_idx += short(pred);
|
b_idx += short(pred);
|
||||||
a_idx += short(!pred);
|
a_idx += short(!pred);
|
||||||
|
Loading…
Reference in New Issue
Block a user