mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
@@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
M = top_k_mx.shape[axis or 0]
|
||||
self.assertEqual(M, (kth + N) % N)
|
||||
|
||||
def test_argpartition(self):
|
||||
x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3))
|
||||
out = mx.argpartition(x, kth=1, axis=0)
|
||||
expected = mx.array([[0, 0, 0], [1, 1, 1]])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
x = mx.array([[1, 2], [3, 4]]).T
|
||||
out = mx.argpartition(x, kth=1, axis=0)
|
||||
expected = mx.array([[0, 0], [1, 1]])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
@unittest.skipIf(
|
||||
os.getenv("LOW_MEMORY", None) is not None,
|
||||
"This test requires a lot of memory",
|
||||
|
Reference in New Issue
Block a user