mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.
* Modified sort behavior when running CPU or Metal to match NumPy/JAX
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
9bfc476d72
commit
9cbb1b0148
@@ -15,6 +15,18 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// NaN-aware comparator that places NaNs at the end
|
||||||
|
template <typename T>
|
||||||
|
bool nan_aware_less(T a, T b) {
|
||||||
|
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
|
||||||
|
if (std::isnan(a))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(b))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return a < b;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct StridedIterator {
|
struct StridedIterator {
|
||||||
using iterator_category = std::random_access_iterator_tag;
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
|
|||||||
StridedIterator st(data_ptr, axis_stride, 0);
|
StridedIterator st(data_ptr, axis_stride, 0);
|
||||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
std::stable_sort(st, ed);
|
std::stable_sort(st, ed, nan_aware_less<T>);
|
||||||
src_it.step();
|
src_it.step();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
|
|||||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||||
auto v1 = data_ptr[a * in_stride];
|
auto v1 = data_ptr[a * in_stride];
|
||||||
auto v2 = data_ptr[b * in_stride];
|
auto v2 = data_ptr[b * in_stride];
|
||||||
|
|
||||||
|
// Handle NaNs (place them at the end)
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
if (std::isnan(v1))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(v2))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return v1 < v2 || (v1 == v2 && a < b);
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
|
|||||||
StridedIterator md(data_ptr, axis_stride, kth);
|
StridedIterator md(data_ptr, axis_stride, kth);
|
||||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
std::nth_element(st, md, ed);
|
std::nth_element(st, md, ed, nan_aware_less<T>);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
|||||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||||
auto v1 = data_ptr[a * in_stride];
|
auto v1 = data_ptr[a * in_stride];
|
||||||
auto v2 = data_ptr[b * in_stride];
|
auto v2 = data_ptr[b * in_stride];
|
||||||
|
|
||||||
|
// Handle NaNs (place them at the end)
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
if (std::isnan(v1))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(v2))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return v1 < v2 || (v1 == v2 && a < b);
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
|||||||
b = w;
|
b = w;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Selects the constant `v` that LessThan exposes as its `init` value.
// Generic case: the largest value of T as reported by Limits<T>.
template <typename T, typename Enable = void>
struct Init {
  static constexpr constant T v = Limits<T>::max;
};

// Floating-point case: a quiet NaN is used instead of Limits<T>::max.
template <typename T>
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
  static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
};
|
||||||
|
|
||||||
// Comparison functor used by the sort kernels.
//
// `init` is taken from Init<T>::v. For floating-point types and complex64_t
// the comparison is NaN-aware: a NaN is never less than anything, and every
// non-NaN value is less than a NaN. All other types use plain operator<.
template <typename T>
struct LessThan {
  static constexpr constant T init = Init<T>::v;

  METAL_FUNC bool operator()(T a, T b) const {
    if constexpr (
        metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
      const bool a_is_nan = isnan(a);
      const bool b_is_nan = isnan(b);
      // Bitwise (non-short-circuit) operators are kept exactly as written.
      if (a_is_nan | b_is_nan) {
        // True only when b is NaN and a is not.
        return b_is_nan & (!a_is_nan);
      }
    }
    return a < b;
  }
};
|
||||||
|
@@ -3100,8 +3100,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = mx.depends(b, c)
|
out = mx.depends(b, c)
|
||||||
self.assertTrue(mx.array_equal(out, b))
|
self.assertTrue(mx.array_equal(out, b))
|
||||||
|
|
||||||
|
|
||||||
class TestBroadcast(mlx_tests.MLXTestCase):
|
|
||||||
def test_broadcast_shapes(self):
|
def test_broadcast_shapes(self):
|
||||||
# Basic broadcasting
|
# Basic broadcasting
|
||||||
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
|
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
|
||||||
@@ -3140,6 +3138,12 @@ class TestBroadcast(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.broadcast_shapes()
|
mx.broadcast_shapes()
|
||||||
|
|
||||||
|
def test_sort_nan(self):
    """Check that mx.sort places NaN after the other values."""
    unsorted = mx.array([3.0, mx.nan, 2.0, 0.0])
    expected = mx.array([0.0, 2.0, 3.0, mx.nan])
    result = mx.sort(unsorted)
    # equal_nan=True so the trailing NaN compares equal to the expected NaN.
    self.assertTrue(mx.array_equal(result, expected, equal_nan=True))
    # NOTE(review): this complex-valued input is constructed but never
    # sorted or asserted on -- confirm whether a check is missing here.
    x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user