Fix multi-block sort stride management (#1169)

* Fix multi-block sort stride management

* Add seed to tests
This commit is contained in:
Jagrit Digani
2024-05-31 11:10:54 -07:00
committed by GitHub
parent 9f0df51f8d
commit 76b6cece46
2 changed files with 33 additions and 9 deletions

View File

@@ -222,22 +222,24 @@ void multi_block_sort(
// Copy outputs with appropriate strides
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
if (axis == strided_out_arr.ndim() - 1) {
if (axis == in.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
} else {
std::vector<int> strided_out_shape = strided_out_arr.shape();
std::vector<size_t> strided_out_str = strided_out_arr.strides();
std::vector<int> strided_out_shape = in.shape();
int out_axis_shape = strided_out_shape[axis];
int out_axis_str = strided_out_str[axis];
strided_out_shape.erase(strided_out_shape.begin() + axis);
strided_out_str.erase(strided_out_str.begin() + axis);
strided_out_shape.push_back(out_axis_shape);
strided_out_str.push_back(out_axis_str);
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
std::vector<size_t> strided_out_str(in.ndim(), 1);
for (int i = in.ndim() - 2; i >= 0; --i) {
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
}
strided_out_str.erase(strided_out_str.end() - 1);
strided_out_str.insert(strided_out_str.begin() + axis, 1);
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
strided_out_slice.copy_shared_buffer(
strided_out_arr,
strided_out_str,