Explicit barriers with concurrent dispatch (#977)

This commit is contained in:
Awni Hannun
2024-04-10 21:45:31 -07:00
committed by GitHub
parent 8580d997ff
commit 12d4507ee3
21 changed files with 326 additions and 267 deletions

View File

@@ -57,13 +57,13 @@ void single_block_sort(
}
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
@@ -131,7 +131,7 @@ void multi_block_sort(
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
// Do blockwise sort
{
@@ -142,9 +142,9 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, dev_vals_0, 1);
set_array_buffer(compute_encoder, dev_idxs_0, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
@@ -181,9 +181,9 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
@@ -202,11 +202,11 @@ void multi_block_sort(
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
set_array_buffer(compute_encoder, dev_vals_out, 3);
set_array_buffer(compute_encoder, dev_idxs_out, 4);
compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);