mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management * Add seed to tests
This commit is contained in:
		| @@ -1754,6 +1754,9 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                     self.assertTrue(np.array_equal(d_np, d_mx)) | ||||
|                     self.assertEqual(c_mx.dtype, mx.uint32) | ||||
|  | ||||
|         # Set random seed | ||||
|         np.random.seed(0) | ||||
|  | ||||
|         # Test multi-block sort | ||||
|         a_np = np.random.normal(size=(32769,)).astype(np.float32) | ||||
|         a_mx = mx.array(a_np) | ||||
| @@ -1764,6 +1767,25 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(np.array_equal(b_np, b_mx)) | ||||
|         self.assertEqual(b_mx.dtype, a_mx.dtype) | ||||
|  | ||||
|         # Test multi-dum multi-block sort | ||||
|         a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32) | ||||
|         a_mx = mx.array(a_np) | ||||
|  | ||||
|         b_np = np.sort(a_np, axis=-1) | ||||
|         b_mx = mx.sort(a_mx, axis=-1) | ||||
|  | ||||
|         self.assertTrue(np.array_equal(b_np, b_mx)) | ||||
|         self.assertEqual(b_mx.dtype, a_mx.dtype) | ||||
|  | ||||
|         a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32) | ||||
|         a_mx = mx.array(a_np) | ||||
|  | ||||
|         b_np = np.sort(a_np, axis=1) | ||||
|         b_mx = mx.sort(a_mx, axis=1) | ||||
|  | ||||
|         self.assertTrue(np.array_equal(b_np, b_mx)) | ||||
|         self.assertEqual(b_mx.dtype, a_mx.dtype) | ||||
|  | ||||
|     def test_partition(self): | ||||
|         shape = (3, 4, 5) | ||||
|         for dtype in ("int32", "float32"): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani