This commit is contained in:
Awni Hannun
2025-02-05 17:16:27 -08:00
committed by GitHub
parent ca305afdbe
commit 9174606d4c
2 changed files with 23 additions and 4 deletions

View File

@@ -38,10 +38,6 @@ void single_block_sort(
int size_sorted_axis = in.shape(axis);
int in_stride_sorted_axis = in.strides()[axis];
int out_stride_sorted_axis = out.strides()[axis];
int in_stride_segment_axis =
*std::min_element(in_nc_str.begin(), in_nc_str.end());
int out_stride_segment_axis =
*std::min_element(out_nc_str.begin(), out_nc_str.end());
// We can only use the contiguous kernel if the sorted axis
// has the largest or smallest stride.
@@ -78,6 +74,20 @@ void single_block_sort(
compute_encoder.set_bytes(out_stride_sorted_axis, 4);
if (contiguous) {
// Compute the smallest input and output stride over the entries of
// in_nc_str / out_nc_str. Both minima are initialized to INT32_MAX.
int in_stride_segment_axis = INT32_MAX;
int out_stride_segment_axis = INT32_MAX;
for (int i = 0; i < in_nc_str.size(); i++) {
// Entries whose nc_shape is 1 are skipped and do not affect the minimum.
if (nc_shape[i] == 1) {
continue;
}
// The strides are narrowed to int below, so reject values that do not fit.
if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
throw std::runtime_error("[Sort::eval_gpu] Stride too large.");
}
in_stride_segment_axis =
std::min(in_stride_segment_axis, static_cast<int>(in_nc_str[i]));
out_stride_segment_axis =
std::min(out_stride_segment_axis, static_cast<int>(out_nc_str[i]));
}
// NOTE(review): if every entry is skipped (or in_nc_str is empty), both
// values remain INT32_MAX -- confirm the consumer of these bytes tolerates it.
compute_encoder.set_bytes(in_stride_segment_axis, 5);
compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else {