diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 9eb9960e0..ad8c83e48 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include @@ -102,6 +102,11 @@ void multi_block_sort( int nc_dim = nc_shape.size(); + if (nc_dim == 0) { + nc_shape = {0}; + nc_str = {1}; + } + int size_sorted_axis = in.shape(axis); int stride_sorted_axis = in.strides()[axis]; @@ -143,8 +148,9 @@ void multi_block_sort( compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3); compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4); compute_encoder->setBytes(&nc_dim, sizeof(int), 5); - compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6); - compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7); + compute_encoder->setBytes( + nc_shape.data(), nc_shape.size() * sizeof(int), 6); + compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7); MTL::Size group_dims = MTL::Size(bn, 1, 1); MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1); @@ -158,7 +164,8 @@ void multi_block_sort( array dev_idxs_in = dev_idxs_0; array dev_vals_out = dev_vals_1; array dev_idxs_out = dev_idxs_1; - for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) { + + for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) { dev_vals_in = ping ? dev_vals_1 : dev_vals_0; dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0; dev_vals_out = ping ? dev_vals_0 : dev_vals_1; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5fef45c64..e17c5fce9 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1597,6 +1597,16 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(d_np, d_mx)) self.assertEqual(c_mx.dtype, mx.uint32) + # Test multi-block sort + a_np = np.random.normal(size=(32769,)).astype(np.float32) + a_mx = mx.array(a_np) + + b_np = np.sort(a_np) + b_mx = mx.sort(a_mx) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + def test_partition(self): shape = (3, 4, 5) for dtype in ("int32", "float32"):