Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)

* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Manuel Villanueva
2025-10-13 16:36:45 -05:00
committed by GitHub
parent 9bfc476d72
commit 9cbb1b0148
3 changed files with 58 additions and 7 deletions

View File

@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
b = w;
}
// Primary template providing the constant `v` that LessThan uses as its
// `init` value. For any T not matched by a partial specialization, `v` is
// Limits<T>::max. The unnamed second parameter defaults to void so that
// partial specializations can select themselves through metal::enable_if_t.
template <typename T, typename = void>
struct Init {
static constexpr constant T v = Limits<T>::max;
};
// Partial specialization selected when metal::is_floating_point_v<T> is true:
// `v` is a quiet NaN instead of Limits<T>::max.
// NOTE(review): presumably chosen because LessThan orders NaN after every
// non-NaN value, so a NaN init would compare last -- confirm against the
// kernels that consume `init`.
template <typename T>
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
};
// Strict "less than" comparator that also carries an `init` constant.
// For floating-point T and complex64_t, NaN is ordered after every non-NaN
// value and two NaNs compare as not-less; the commit title states the goal is
// to match NumPy/JAX sort behavior. All other types use plain `a < b`.
template <typename T>
struct LessThan {
// NOTE(review): this hunk is rendered without +/- markers. The next two lines
// appear to be the pre-change declarations, and the two lines after them
// their replacements (init now comes from Init<T>, operator() is now const).
// Confirm against the raw diff before treating this text as compilable code.
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
static constexpr constant T init = Init<T>::v;
METAL_FUNC bool operator()(T a, T b) const {
// Compile-time branch: only NaN-capable types pay for the isnan checks.
if constexpr (
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
// NOTE(review): complex64_t takes this NaN-aware path, but the only Init
// specialization visible in this hunk is keyed on is_floating_point_v.
// If that trait is false for complex64_t, `init` stays Limits<T>::max --
// confirm that value still sorts last when the input contains NaN.
bool an = isnan(a);
bool bn = isnan(b);
// Bitwise | and & on bools: both operands are already evaluated, so no
// short-circuiting is needed.
if (an | bn) {
// At least one NaN: a precedes b only when a is a number and b is NaN.
return (!an) & bn;
}
}
// Neither operand is NaN, or T does not take the NaN path: ordinary compare.
return a < b;
}
};