mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Modified sort behavior when running CPU or Metal to match NumPy/JAX (#2667)
* Modified sort behavior when running CPU or Metal to match NumPy/JAX sorting behavior. * Modified sort behavior when running CPU or Metal to match NumPy/JAX * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
9bfc476d72
commit
9cbb1b0148
@@ -3100,8 +3100,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.depends(b, c)
|
||||
self.assertTrue(mx.array_equal(out, b))
|
||||
|
||||
|
||||
class TestBroadcast(mlx_tests.MLXTestCase):
|
||||
def test_broadcast_shapes(self):
|
||||
# Basic broadcasting
|
||||
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
|
||||
@@ -3140,6 +3138,12 @@ class TestBroadcast(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.broadcast_shapes()
|
||||
|
||||
def test_sort_nan(self):
|
||||
x = mx.array([3.0, mx.nan, 2.0, 0.0])
|
||||
expected = mx.array([0.0, 2.0, 3.0, mx.nan])
|
||||
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
|
||||
x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user