Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
This commit is contained in:
Awni Hannun
2024-11-08 11:50:21 -08:00
committed by GitHub
parent 59247c2b62
commit 9f0d5c12fc
27 changed files with 469 additions and 484 deletions

View File

@@ -68,29 +68,29 @@ void single_block_sort(
// Prepare command encoder
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Set inputs
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
compute_encoder.set_bytes(size_sorted_axis, 2);
compute_encoder.set_bytes(in_stride_sorted_axis, 3);
compute_encoder.set_bytes(out_stride_sorted_axis, 4);
if (contiguous) {
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
compute_encoder.set_bytes(in_stride_segment_axis, 5);
compute_encoder.set_bytes(out_stride_segment_axis, 6);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
compute_encoder.set_bytes(nc_dim, 5);
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(in_nc_str, 7);
compute_encoder.set_vector_bytes(out_nc_str, 8);
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void multi_block_sort(
@@ -152,22 +152,21 @@ void multi_block_sort(
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(dev_vals_0, 1);
compute_encoder.set_output_array(dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(
nc_shape.data(), nc_shape.size() * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder.set_bytes(stride_sorted_axis, 4);
compute_encoder.set_bytes(nc_dim, 5);
compute_encoder.set_vector_bytes(nc_shape, 6);
compute_encoder.set_vector_bytes(nc_str, 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do merges
@@ -194,19 +193,19 @@ void multi_block_sort(
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_output_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
compute_encoder.set_bytes(size_sorted_axis, 3);
compute_encoder.set_bytes(merge_tiles, 4);
compute_encoder.set_bytes(n_blocks, 5);
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do merge
@@ -217,21 +216,21 @@ void multi_block_sort(
auto kernel =
get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(block_partitions, 0);
compute_encoder.set_input_array(dev_vals_in, 1);
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder.set_output_array(dev_vals_out, 3);
compute_encoder.set_output_array(dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
compute_encoder.set_bytes(size_sorted_axis, 5);
compute_encoder.set_bytes(merge_tiles, 6);
compute_encoder.set_bytes(n_blocks, 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}