mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fix argpartition + faster {arg} sorts / partitions (#1453)
This commit is contained in:
		| @@ -1954,6 +1954,17 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                             M = top_k_mx.shape[axis or 0] | ||||
|                             self.assertEqual(M, (kth + N) % N) | ||||
|  | ||||
|     def test_argpartition(self): | ||||
|         x = mx.broadcast_to(mx.array([1, 2, 3]), (2, 3)) | ||||
|         out = mx.argpartition(x, kth=1, axis=0) | ||||
|         expected = mx.array([[0, 0, 0], [1, 1, 1]]) | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|         x = mx.array([[1, 2], [3, 4]]).T | ||||
|         out = mx.argpartition(x, kth=1, axis=0) | ||||
|         expected = mx.array([[0, 0], [1, 1]]) | ||||
|         self.assertTrue(mx.array_equal(out, expected)) | ||||
|  | ||||
|     @unittest.skipIf( | ||||
|         os.getenv("LOW_MEMORY", None) is not None, | ||||
|         "This test requires a lot of memory", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun