Fix multiblock sort limits (#906)

* Fix multiblock sort limits

* Fix metal validation error
This commit is contained in:
Jagrit Digani
2024-03-26 14:00:00 -07:00
committed by GitHub
parent 5611e1a95e
commit 925014b661
2 changed files with 21 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
@@ -102,6 +102,11 @@ void multi_block_sort(
int nc_dim = nc_shape.size();
if (nc_dim == 0) {
nc_shape = {0};
nc_str = {1};
}
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
@@ -143,8 +148,9 @@ void multi_block_sort(
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
@@ -158,7 +164,8 @@ void multi_block_sort(
array dev_idxs_in = dev_idxs_0;
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;