mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management * Add seed to tests
This commit is contained in:
@@ -222,22 +222,24 @@ void multi_block_sort(
|
||||
// Copy outputs with appropriate strides
|
||||
array strided_out_arr = argsort ? dev_idxs_out : dev_vals_out;
|
||||
|
||||
if (axis == strided_out_arr.ndim() - 1) {
|
||||
if (axis == in.ndim() - 1) {
|
||||
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
|
||||
} else {
|
||||
std::vector<int> strided_out_shape = strided_out_arr.shape();
|
||||
std::vector<size_t> strided_out_str = strided_out_arr.strides();
|
||||
|
||||
std::vector<int> strided_out_shape = in.shape();
|
||||
int out_axis_shape = strided_out_shape[axis];
|
||||
int out_axis_str = strided_out_str[axis];
|
||||
|
||||
strided_out_shape.erase(strided_out_shape.begin() + axis);
|
||||
strided_out_str.erase(strided_out_str.begin() + axis);
|
||||
|
||||
strided_out_shape.push_back(out_axis_shape);
|
||||
strided_out_str.push_back(out_axis_str);
|
||||
|
||||
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
|
||||
std::vector<size_t> strided_out_str(in.ndim(), 1);
|
||||
for (int i = in.ndim() - 2; i >= 0; --i) {
|
||||
strided_out_str[i] = strided_out_str[i + 1] * strided_out_shape[i + 1];
|
||||
}
|
||||
|
||||
strided_out_str.erase(strided_out_str.end() - 1);
|
||||
strided_out_str.insert(strided_out_str.begin() + axis, 1);
|
||||
|
||||
array strided_out_slice(in.shape(), out.dtype(), nullptr, {});
|
||||
strided_out_slice.copy_shared_buffer(
|
||||
strided_out_arr,
|
||||
strided_out_str,
|
||||
|
Reference in New Issue
Block a user