diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 91d074c6b..2863ca938 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -82,9 +82,17 @@ void single_block_sort( compute_encoder.set_bytes(out_stride_segment_axis, 6); } else { compute_encoder.set_bytes(nc_dim, 5); - compute_encoder.set_vector_bytes(nc_shape, 6); - compute_encoder.set_vector_bytes(in_nc_str, 7); - compute_encoder.set_vector_bytes(out_nc_str, 8); + if (nc_shape.empty()) { + int shape = 0; + int64_t stride = 0; + compute_encoder.set_bytes(shape, 6); + compute_encoder.set_bytes(stride, 7); + compute_encoder.set_bytes(stride, 8); + } else { + compute_encoder.set_vector_bytes(nc_shape, 6); + compute_encoder.set_vector_bytes(in_nc_str, 7); + compute_encoder.set_vector_bytes(out_nc_str, 8); + } } MTL::Size group_dims = MTL::Size(bn, 1, 1); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9a5918011..7d010135a 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1956,6 +1956,12 @@ class TestOps(mlx_tests.MLXTestCase): b_mx = mx.sort(a_mx) self.assertTrue(np.array_equal(b_np, b_mx)) + # 1D strided sort + a = mx.array([[4, 3], [2, 1], [5, 4], [3, 2]]) + out = mx.argsort(a[:, 1]) + expected = mx.array([1, 3, 0, 2], dtype=mx.uint32) + self.assertTrue(mx.array_equal(out, expected)) + def test_partition(self): shape = (3, 4, 5) for dtype in ("int32", "float32"):