[CUDA] fix sort (#2550)

* [CUDA] fix sort

* fix test
This commit is contained in:
Awni Hannun
2025-08-27 19:48:43 -07:00
committed by GitHub
parent 31c6f6e33f
commit 7ef8a6f2d5
2 changed files with 19 additions and 5 deletions

View File

@@ -2191,6 +2191,12 @@ class TestOps(mlx_tests.MLXTestCase):
y_mx = mx.sort(mx.array(x), axis=-2)
self.assertTrue(np.array_equal(y_np, y_mx))
# Test many segments
a = mx.random.uniform(shape=(512, 128))
y_mx = mx.sort(a, axis=-1)
y_np = np.sort(np.array(a), axis=-1)
self.assertTrue(np.array_equal(y_np, y_mx))
def test_partition(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):