mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-21 18:28:11 +08:00
Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior. * Modified sort behavior when running CPU or Metal to match NumPy/JAX * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
9bfc476d72
commit
9cbb1b0148
@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
||||
b = w;
|
||||
}
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct Init {
|
||||
static constexpr constant T v = Limits<T>::max;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
|
||||
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LessThan {
|
||||
static constexpr constant T init = Limits<T>::max;
|
||||
|
||||
METAL_FUNC bool operator()(T a, T b) {
|
||||
static constexpr constant T init = Init<T>::v;
|
||||
METAL_FUNC bool operator()(T a, T b) const {
|
||||
if constexpr (
|
||||
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
|
||||
bool an = isnan(a);
|
||||
bool bn = isnan(b);
|
||||
if (an | bn) {
|
||||
return (!an) & bn;
|
||||
}
|
||||
}
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user