Fix copy in the sort primitive (#1383)

This commit is contained in:
Angelos Katharopoulos 2024-08-31 08:32:14 -07:00 committed by GitHub
parent 0d302cd25b
commit 58dca7d846
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 28 deletions

View File

@ -156,6 +156,7 @@ void copy_gpu_inplace(
array& out,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
@ -167,6 +168,7 @@ void copy_gpu_inplace(
int64_t ioffset,
CopyType ctype,
const Stream& s) {
assert(in.shape() == out.shape());
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);

View File

@ -236,35 +236,21 @@ void multi_block_sort(
}
// Copy outputs with appropriate strides
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
if (axis == in.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
} else {
std::vector<int> strided_out_shape = in.shape();
int out_axis_shape = strided_out_shape[axis];
strided_out_shape.erase(strided_out_shape.begin() + axis);
strided_out_shape.push_back(out_axis_shape);
std::vector<size_t> strided_out_str(in.ndim(), 1);
for (int i = in.ndim() - 2; i >= 0; --i) {
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
}
strided_out_str.erase(strided_out_str.end() - 1);
strided_out_str.insert(strided_out_str.begin() + axis, 1);
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
strided_out_slice.copy_shared_buffer(
strided_out_arr,
strided_out_str,
strided_out_arr.flags(),
strided_out_arr.size(),
0);
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
auto strides = out.strides();
for (int ax = axis + 1; ax < strides.size(); ax++) {
strides[ax] *= out.shape(axis);
}
strides[axis] = 1;
copy_gpu_inplace(
(argsort) ? dev_idxs_out : dev_vals_out,
out,
out.shape(),
strides,
out.strides(),
0,
0,
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
s);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(