fix argpartition + faster {arg} sorts / partitions (#1453)

This commit is contained in:
Awni Hannun
2024-10-03 14:21:25 -07:00
committed by GitHub
parent 5523d9c426
commit 1bdc038bf9
2 changed files with 59 additions and 26 deletions

View File

@@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase):
M = top_k_mx.shape[axis or 0]
self.assertEqual(M, (kth + N) % N)
def test_argpartition(self):
x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3))
out = mx.argpartition(x, kth=1, axis=0)
expected = mx.array([[0, 0, 0], [1, 1, 1]])
self.assertTrue(mx.array_equal(out, expected))
x = mx.array([[1, 2], [3, 4]]).T
out = mx.argpartition(x, kth=1, axis=0)
expected = mx.array([[0, 0], [1, 1]])
self.assertTrue(mx.array_equal(out, expected))
@unittest.skipIf(
os.getenv("LOW_MEMORY", None) is not None,
"This test requires a lot of memory",