Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)

* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior.

* Modified sort behavior when running CPU or Metal to match NumPy/JAX

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Manuel Villanueva
2025-10-13 16:36:45 -05:00
committed by GitHub
parent 9bfc476d72
commit 9cbb1b0148
3 changed files with 58 additions and 7 deletions

View File

@@ -3100,8 +3100,6 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.depends(b, c)
self.assertTrue(mx.array_equal(out, b))
class TestBroadcast(mlx_tests.MLXTestCase):
def test_broadcast_shapes(self):
# Basic broadcasting
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
@@ -3140,6 +3138,12 @@ class TestBroadcast(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
mx.broadcast_shapes()
def test_sort_nan(self):
x = mx.array([3.0, mx.nan, 2.0, 0.0])
expected = mx.array([0.0, 2.0, 3.0, mx.nan])
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()