Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-16 22:11:15 +08:00)
Fix copy in the sort primitive (#1383)
This commit is contained in:
parent 0d302cd25b
commit 58dca7d846
@ -156,6 +156,7 @@ void copy_gpu_inplace(
|
||||
array& out,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||
}
|
||||
@ -167,6 +168,7 @@ void copy_gpu_inplace(
|
||||
int64_t ioffset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
assert(in.shape() == out.shape());
|
||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||
return copy_gpu_inplace(
|
||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||
|
@ -236,35 +236,21 @@ void multi_block_sort(
|
||||
}
|
||||
|
||||
// Copy outputs with appropriate strides
|
||||
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
|
||||
|
||||
if (axis == in.ndim() - 1) {
|
||||
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
|
||||
} else {
|
||||
std::vector<int> strided_out_shape = in.shape();
|
||||
int out_axis_shape = strided_out_shape[axis];
|
||||
|
||||
strided_out_shape.erase(strided_out_shape.begin() + axis);
|
||||
strided_out_shape.push_back(out_axis_shape);
|
||||
|
||||
std::vector<size_t> strided_out_str(in.ndim(), 1);
|
||||
for (int i = in.ndim() - 2; i >= 0; --i) {
|
||||
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
|
||||
}
|
||||
|
||||
strided_out_str.erase(strided_out_str.end() - 1);
|
||||
strided_out_str.insert(strided_out_str.begin() + axis, 1);
|
||||
|
||||
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
strided_out_slice.copy_shared_buffer(
|
||||
strided_out_arr,
|
||||
strided_out_str,
|
||||
strided_out_arr.flags(),
|
||||
strided_out_arr.size(),
|
||||
0);
|
||||
|
||||
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
|
||||
auto strides = out.strides();
|
||||
for (int ax = axis + 1; ax < strides.size(); ax++) {
|
||||
strides[ax] *= out.shape(axis);
|
||||
}
|
||||
strides[axis] = 1;
|
||||
copy_gpu_inplace(
|
||||
(argsort) ? dev_idxs_out : dev_vals_out,
|
||||
out,
|
||||
out.shape(),
|
||||
strides,
|
||||
out.strides(),
|
||||
0,
|
||||
0,
|
||||
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
|
||||
s);
|
||||
|
||||
// Clear copies
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
|
Loading…
Reference in New Issue
Block a user