mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Explicit barriers with concurrent dispatch (#977)
This commit is contained in:
@@ -57,13 +57,13 @@ void single_block_sort(
|
||||
}
|
||||
|
||||
// Prepare command encoder
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Set inputs
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
|
||||
|
||||
@@ -131,7 +131,7 @@ void multi_block_sort(
|
||||
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
|
||||
|
||||
// Prepare command encoder
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
// Do blockwise sort
|
||||
{
|
||||
@@ -142,9 +142,9 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_0, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_0, 2);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(dev_vals_0, 1);
|
||||
compute_encoder.set_output_array(dev_idxs_0, 2);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
|
||||
@@ -181,9 +181,9 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||
compute_encoder.set_output_array(block_partitions, 0);
|
||||
compute_encoder.set_input_array(dev_vals_in, 1);
|
||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
|
||||
|
||||
@@ -202,11 +202,11 @@ void multi_block_sort(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
set_array_buffer(compute_encoder, block_partitions, 0);
|
||||
set_array_buffer(compute_encoder, dev_vals_in, 1);
|
||||
set_array_buffer(compute_encoder, dev_idxs_in, 2);
|
||||
set_array_buffer(compute_encoder, dev_vals_out, 3);
|
||||
set_array_buffer(compute_encoder, dev_idxs_out, 4);
|
||||
compute_encoder.set_input_array(block_partitions, 0);
|
||||
compute_encoder.set_input_array(dev_vals_in, 1);
|
||||
compute_encoder.set_input_array(dev_idxs_in, 2);
|
||||
compute_encoder.set_output_array(dev_vals_out, 3);
|
||||
compute_encoder.set_output_array(dev_idxs_out, 4);
|
||||
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
|
||||
|
||||
Reference in New Issue
Block a user