mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix strided sort bug (#1236)
* Use output strides in sort kernel * fix zero strides bug
This commit is contained in:
@@ -24,8 +24,11 @@ void single_block_sort(
|
||||
// Prepare shapes
|
||||
int n_rows = in.size() / in.shape(axis);
|
||||
|
||||
std::vector<size_t> nc_str = in.strides();
|
||||
nc_str.erase(nc_str.begin() + axis);
|
||||
std::vector<size_t> in_nc_str = in.strides();
|
||||
in_nc_str.erase(in_nc_str.begin() + axis);
|
||||
|
||||
std::vector<size_t> out_nc_str = out.strides();
|
||||
out_nc_str.erase(out_nc_str.begin() + axis);
|
||||
|
||||
std::vector<int> nc_shape = in.shape();
|
||||
nc_shape.erase(nc_shape.begin() + axis);
|
||||
@@ -33,21 +36,28 @@ void single_block_sort(
|
||||
int nc_dim = nc_shape.size();
|
||||
|
||||
int size_sorted_axis = in.shape(axis);
|
||||
int stride_sorted_axis = in.strides()[axis];
|
||||
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
|
||||
int in_stride_sorted_axis = in.strides()[axis];
|
||||
int out_stride_sorted_axis = out.strides()[axis];
|
||||
int in_stride_segment_axis =
|
||||
*std::min_element(in_nc_str.begin(), in_nc_str.end());
|
||||
int out_stride_segment_axis =
|
||||
*std::min_element(out_nc_str.begin(), out_nc_str.end());
|
||||
|
||||
// Check if remaining strides are contiguous
|
||||
bool contiguous_write = true;
|
||||
if (axis != in.ndim() - 1 && axis != 0) {
|
||||
for (int i = 0; i < nc_str.size() - 1; ++i) {
|
||||
size_t expected = nc_str[i + 1] * nc_str[i + 1];
|
||||
contiguous_write &= (nc_str[i] == expected);
|
||||
}
|
||||
}
|
||||
// We can only use the contiguous kernel if the sorted axis
|
||||
// has the largest or smallest stride.
|
||||
// We also need the input to be contiguous
|
||||
bool contiguous = in.flags().contiguous;
|
||||
auto check_strides = [](array x, int sort_stride) {
|
||||
int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
|
||||
int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
|
||||
return sort_stride == min_stride || sort_stride == max_stride;
|
||||
};
|
||||
contiguous &= check_strides(in, in_stride_sorted_axis);
|
||||
contiguous &= check_strides(out, out_stride_sorted_axis);
|
||||
|
||||
// Prepare kernel name
|
||||
std::ostringstream kname;
|
||||
kname << (contiguous_write ? "c" : "nc");
|
||||
kname << (contiguous ? "c" : "nc");
|
||||
if (argsort) {
|
||||
kname << "arg";
|
||||
}
|
||||
@@ -64,14 +74,17 @@ void single_block_sort(
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
|
||||
|
||||
if (contiguous_write) {
|
||||
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
|
||||
if (contiguous) {
|
||||
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
|
||||
}
|
||||
|
||||
MTL::Size group_dims = MTL::Size(bn, 1, 1);
|
||||
|
||||
Reference in New Issue
Block a user