From 9174606d4c33db79adec02f4823905edc5d12552 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Feb 2025 17:16:27 -0800 Subject: [PATCH] fix sort (#1835) --- mlx/backend/metal/sort.cpp | 18 ++++++++++++++---- python/tests/test_ops.py | 9 +++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 39036828e..ea16fbf89 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -38,10 +38,6 @@ void single_block_sort( int size_sorted_axis = in.shape(axis); int in_stride_sorted_axis = in.strides()[axis]; int out_stride_sorted_axis = out.strides()[axis]; - int in_stride_segment_axis = - *std::min_element(in_nc_str.begin(), in_nc_str.end()); - int out_stride_segment_axis = - *std::min_element(out_nc_str.begin(), out_nc_str.end()); // We can only use the contiguous kernel if the sorted axis // has the largest or smallest stride. @@ -78,6 +74,20 @@ void single_block_sort( compute_encoder.set_bytes(out_stride_sorted_axis, 4); if (contiguous) { + int in_stride_segment_axis = INT32_MAX; + int out_stride_segment_axis = INT32_MAX; + for (int i = 0; i < in_nc_str.size(); i++) { + if (nc_shape[i] == 1) { + continue; + } + if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { + throw std::runtime_error("[Sort::eval_gpu] Stride too large."); + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, static_cast<int>(in_nc_str[i])); + out_stride_segment_axis = + std::min(out_stride_segment_axis, static_cast<int>(out_nc_str[i])); + } compute_encoder.set_bytes(in_stride_segment_axis, 5); compute_encoder.set_bytes(out_stride_segment_axis, 6); } else { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index ce69c29f1..31a65f524 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2010,6 +2010,15 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.array([1, 3, 0, 2], dtype=mx.uint32) self.assertTrue(mx.array_equal(out, expected)) + # Test array with 
singleton dim + out = mx.sort(mx.array([1, 2, 3]), axis=0) + self.assertTrue(mx.array_equal(out, mx.array([1, 2, 3]))) + + x = np.random.uniform(size=(1, 4, 8, 1)).astype(np.float32) + y_np = np.sort(x, axis=-2) + y_mx = mx.sort(mx.array(x), axis=-2) + self.assertTrue(np.array_equal(y_np, y_mx)) + def test_partition(self): shape = (3, 4, 5) for dtype in ("int32", "float32"):