diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index 0b8471d32..56e7b939c 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -15,6 +15,18 @@ namespace mlx::core {
 
 namespace {
 
+// Strict-weak "less" that orders every NaN after all non-NaN values
+template <typename T>
+bool nan_aware_less(T a, T b) {
+  if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
+    if (std::isnan(a))
+      return false;
+    if (std::isnan(b))
+      return true;
+  }
+  return a < b;
+}
+
 template <typename T>
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
     StridedIterator st(data_ptr, axis_stride, 0);
     StridedIterator ed(data_ptr, axis_stride, axis_size);
 
-    std::stable_sort(st, ed);
+    std::stable_sort(st, ed, nan_aware_less<T>);
     src_it.step();
   }
 }
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
     std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
       auto v1 = data_ptr[a * in_stride];
       auto v2 = data_ptr[b * in_stride];
+
+      // NaNs sort last; branch is compiled out for non-floating-point T
+      if constexpr (std::is_floating_point_v<T>) {
+        if (std::isnan(v1))
+          return false;
+        if (std::isnan(v2))
+          return true;
+      }
+
       return v1 < v2 || (v1 == v2 && a < b);
     });
   }
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
     StridedIterator md(data_ptr, axis_stride, kth);
     StridedIterator ed(data_ptr, axis_stride, axis_size);
 
-    std::nth_element(st, md, ed);
+    std::nth_element(st, md, ed, nan_aware_less<T>);
   }
 }
 
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
     std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
       auto v1 = data_ptr[a * in_stride];
       auto v2 = data_ptr[b * in_stride];
+
+      // NaNs sort last; branch is compiled out for non-floating-point T
+      if constexpr (std::is_floating_point_v<T>) {
+        if (std::isnan(v1))
+          return false;
+        if (std::isnan(v2))
+          return true;
+      }
+
       return v1 < v2 || (v1 == v2 && a < b);
     });
   }
diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h
index 5823e4300..c4392878e 100644
--- a/mlx/backend/metal/kernels/sort.h
+++ b/mlx/backend/metal/kernels/sort.h
@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
   b = w;
 }
 
+template <typename T, typename = void>
+struct Init {
+  static constexpr constant T v = Limits<T>::max;
+};
+
+template <typename T>
+struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
+  static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
+};
+
 template <typename T>
 struct LessThan {
-  static constexpr constant T init = Limits<T>::max;
-
-  METAL_FUNC bool operator()(T a, T b) {
+  static constexpr constant T init = Init<T>::v;
+  METAL_FUNC bool operator()(T a, T b) const {
+    if constexpr (
+        metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
+      bool an = isnan(a);
+      bool bn = isnan(b);
+      if (an | bn) {
+        return (!an) & bn;
+      }
+    }
     return a < b;
   }
 };
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index af262ed51..30ba12417 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -3100,8 +3100,6 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.depends(b, c)
         self.assertTrue(mx.array_equal(out, b))
 
-
-class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):
         # Basic broadcasting
         self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
@@ -3140,6 +3138,12 @@ class TestBroadcast(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             mx.broadcast_shapes()
 
+    def test_sort_nan(self):
+        x = mx.array([3.0, mx.nan, 2.0, 0.0])
+        expected = mx.array([0.0, 2.0, 3.0, mx.nan])
+        self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
+        self.assertEqual(mx.argsort(x).tolist(), [3, 2, 0, 1])
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()