This commit is contained in:
Awni Hannun 2025-02-05 17:16:27 -08:00 committed by GitHub
parent ca305afdbe
commit 9174606d4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 4 deletions

View File

@ -38,10 +38,6 @@ void single_block_sort(
int size_sorted_axis = in.shape(axis);
int in_stride_sorted_axis = in.strides()[axis];
int out_stride_sorted_axis = out.strides()[axis];
int in_stride_segment_axis =
*std::min_element(in_nc_str.begin(), in_nc_str.end());
int out_stride_segment_axis =
*std::min_element(out_nc_str.begin(), out_nc_str.end());
// We can only use the contiguous kernel if the sorted axis
// has the largest or smallest stride.
@ -78,6 +74,20 @@ void single_block_sort(
compute_encoder.set_bytes(out_stride_sorted_axis, 4);
if (contiguous) {
int in_stride_segment_axis = INT32_MAX;
int out_stride_segment_axis = INT32_MAX;
for (int i = 0; i < in_nc_str.size(); i++) {
if (nc_shape[i] == 1) {
continue;
}
if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
throw std::runtime_error("[Sort::eval_gpu] Stride too large.");
}
in_stride_segment_axis =
std::min(in_stride_segment_axis, static_cast<int>(in_nc_str[i]));
out_stride_segment_axis =
std::min(out_stride_segment_axis, static_cast<int>(out_nc_str[i]));
}
compute_encoder.set_bytes(in_stride_segment_axis, 5);
compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else {

View File

@ -2010,6 +2010,15 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.array([1, 3, 0, 2], dtype=mx.uint32)
self.assertTrue(mx.array_equal(out, expected))
# Test array with singleton dim
out = mx.sort(mx.array([1, 2, 3]), axis=0)
self.assertTrue(mx.array_equal(out, mx.array([1, 2, 3])))
x = np.random.uniform(size=(1, 4, 8, 1)).astype(np.float32)
y_np = np.sort(x, axis=-2)
y_mx = mx.sort(mx.array(x), axis=-2)
self.assertTrue(np.array_equal(y_np, y_mx))
def test_partition(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):