mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
Fix copy in the sort primitive (#1383)
This commit is contained in:
parent
0d302cd25b
commit
58dca7d846
@ -156,6 +156,7 @@ void copy_gpu_inplace(
|
|||||||
array& out,
|
array& out,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
|
assert(in.shape() == out.shape());
|
||||||
return copy_gpu_inplace(
|
return copy_gpu_inplace(
|
||||||
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
|
||||||
}
|
}
|
||||||
@ -167,6 +168,7 @@ void copy_gpu_inplace(
|
|||||||
int64_t ioffset,
|
int64_t ioffset,
|
||||||
CopyType ctype,
|
CopyType ctype,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
|
assert(in.shape() == out.shape());
|
||||||
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
|
||||||
return copy_gpu_inplace(
|
return copy_gpu_inplace(
|
||||||
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
|
||||||
|
@ -236,35 +236,21 @@ void multi_block_sort(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Copy outputs with appropriate strides
|
// Copy outputs with appropriate strides
|
||||||
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
|
auto strides = out.strides();
|
||||||
|
for (int ax = axis + 1; ax < strides.size(); ax++) {
|
||||||
if (axis == in.ndim() - 1) {
|
strides[ax] *= out.shape(axis);
|
||||||
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
|
|
||||||
} else {
|
|
||||||
std::vector<int> strided_out_shape = in.shape();
|
|
||||||
int out_axis_shape = strided_out_shape[axis];
|
|
||||||
|
|
||||||
strided_out_shape.erase(strided_out_shape.begin() + axis);
|
|
||||||
strided_out_shape.push_back(out_axis_shape);
|
|
||||||
|
|
||||||
std::vector<size_t> strided_out_str(in.ndim(), 1);
|
|
||||||
for (int i = in.ndim() - 2; i >= 0; --i) {
|
|
||||||
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
strided_out_str.erase(strided_out_str.end() - 1);
|
|
||||||
strided_out_str.insert(strided_out_str.begin() + axis, 1);
|
|
||||||
|
|
||||||
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
|
|
||||||
strided_out_slice.copy_shared_buffer(
|
|
||||||
strided_out_arr,
|
|
||||||
strided_out_str,
|
|
||||||
strided_out_arr.flags(),
|
|
||||||
strided_out_arr.size(),
|
|
||||||
0);
|
|
||||||
|
|
||||||
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
|
|
||||||
}
|
}
|
||||||
|
strides[axis] = 1;
|
||||||
|
copy_gpu_inplace(
|
||||||
|
(argsort) ? dev_idxs_out : dev_vals_out,
|
||||||
|
out,
|
||||||
|
out.shape(),
|
||||||
|
strides,
|
||||||
|
out.strides(),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
|
||||||
|
s);
|
||||||
|
|
||||||
// Clear copies
|
// Clear copies
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||||
|
Loading…
Reference in New Issue
Block a user