diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index d594fccce..f4d6e63bd 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -222,22 +222,24 @@ void multi_block_sort( // Copy outputs with appropriate strides array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out; - if (axis == strided_out_arr.ndim() - 1) { + if (axis == in.ndim() - 1) { copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s); } else { - std::vector strided_out_shape = strided_out_arr.shape(); - std::vector strided_out_str = strided_out_arr.strides(); - + std::vector strided_out_shape = in.shape(); int out_axis_shape = strided_out_shape[axis]; - int out_axis_str = strided_out_str[axis]; strided_out_shape.erase(strided_out_shape.begin() + axis); - strided_out_str.erase(strided_out_str.begin() + axis); - strided_out_shape.push_back(out_axis_shape); - strided_out_str.push_back(out_axis_str); - array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {}); + std::vector strided_out_str(in.ndim(), 1); + for (int i = in.ndim() - 2; i >= 0; --i) { + strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1]; + } + + strided_out_str.erase(strided_out_str.end() - 1); + strided_out_str.insert(strided_out_str.begin() + axis, 1); + + array strided_out_slice(in.shape(), out.dtype(), nullptr, {}); strided_out_slice.copy_shared_buffer( strided_out_arr, strided_out_str, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index fb724d2c9..8ee2412bf 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1754,6 +1754,9 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(d_np, d_mx)) self.assertEqual(c_mx.dtype, mx.uint32) + # Set random seed + np.random.seed(0) + # Test multi-block sort a_np = np.random.normal(size=(32769,)).astype(np.float32) a_mx = mx.array(a_np) @@ -1764,6 +1767,25 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(np.array_equal(b_np, b_mx)) 
self.assertEqual(b_mx.dtype, a_mx.dtype) + # Test multi-dim multi-block sort + a_np = np.random.normal(size=(2, 4, 32769)).astype(np.float32) + a_mx = mx.array(a_np) + + b_np = np.sort(a_np, axis=-1) + b_mx = mx.sort(a_mx, axis=-1) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + a_np = np.random.normal(size=(2, 32769, 4)).astype(np.float32) + a_mx = mx.array(a_np) + + b_np = np.sort(a_np, axis=1) + b_mx = mx.sort(a_mx, axis=1) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + def test_partition(self): shape = (3, 4, 5) for dtype in ("int32", "float32"):