Fix GPU sort for large arrays (#1285)

* Fix GPU sort for large arrays
This commit is contained in:
Jagrit Digani
2024-07-24 14:37:10 -07:00
committed by GitHub
parent ebd7135b50
commit 7f914365fd
4 changed files with 34 additions and 25 deletions

View File

@@ -177,6 +177,8 @@ void multi_block_sort(
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
+int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -199,8 +201,9 @@ void multi_block_sort(
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
-MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
+MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);