Fix multiblock sort limits (#906)

* Fix multiblock sort limits

* Fix metal validation error
This commit is contained in:
Jagrit Digani
2024-03-26 14:00:00 -07:00
committed by GitHub
parent 5611e1a95e
commit 925014b661
2 changed files with 21 additions and 4 deletions

View File

@@ -1597,6 +1597,16 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.array_equal(d_np, d_mx))
self.assertEqual(c_mx.dtype, mx.uint32)
# Test multi-block sort
a_np = np.random.normal(size=(32769,)).astype(np.float32)
a_mx = mx.array(a_np)
b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))
self.assertEqual(b_mx.dtype, a_mx.dtype)
def test_partition(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):