Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Fully wrap the command encoder (#1572)
* fully wrap the command encoder
* use consistent style + fix extensions
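For context: the diff below swaps direct metal-cpp calls (setComputePipelineState, setBytes, dispatchThreadgroups) on the underlying MTL::ComputeCommandEncoder for snake_case methods on MLX's CommandEncoder wrapper, so call sites no longer spell out addresses and byte counts. The following is only a rough sketch of what such a wrapper could look like, inferred from the call sites in this diff; it is not the actual MLX implementation, and the member name enc_ is made up.

#include <vector>
#include <Metal/Metal.hpp>  // metal-cpp

// Sketch of a wrapped compute command encoder (illustrative only).
struct CommandEncoder {
  MTL::ComputeCommandEncoder* enc_;  // hypothetical member name

  void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
    enc_->setComputePipelineState(kernel);
  }

  // Scalar argument: the wrapper takes the value and derives address and
  // size, replacing setBytes(&v, sizeof(T), idx) at the call site.
  template <typename T>
  void set_bytes(const T& v, int idx) {
    enc_->setBytes(&v, sizeof(T), idx);
  }

  // Vector argument: pointer and byte count come from the container,
  // replacing setBytes(v.data(), v.size() * sizeof(T), idx).
  template <typename T>
  void set_vector_bytes(const std::vector<T>& v, int idx) {
    enc_->setBytes(v.data(), v.size() * sizeof(T), idx);
  }

  void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims) {
    enc_->dispatchThreadgroups(grid_dims, group_dims);
  }

  // set_input_array / set_output_array (buffer binding) are omitted here.
};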
@@ -68,29 +68,29 @@ void single_block_sort(
   // Prepare command encoder
   auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder->setComputePipelineState(kernel);
+  compute_encoder.set_compute_pipeline_state(kernel);
 
   // Set inputs
   compute_encoder.set_input_array(in, 0);
   compute_encoder.set_output_array(out, 1);
-  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
-  compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
-  compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
+  compute_encoder.set_bytes(size_sorted_axis, 2);
+  compute_encoder.set_bytes(in_stride_sorted_axis, 3);
+  compute_encoder.set_bytes(out_stride_sorted_axis, 4);
 
   if (contiguous) {
-    compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
-    compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
+    compute_encoder.set_bytes(in_stride_segment_axis, 5);
+    compute_encoder.set_bytes(out_stride_segment_axis, 6);
   } else {
-    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
-    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
-    compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
-    compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
+    compute_encoder.set_bytes(nc_dim, 5);
+    compute_encoder.set_vector_bytes(nc_shape, 6);
+    compute_encoder.set_vector_bytes(in_nc_str, 7);
+    compute_encoder.set_vector_bytes(out_nc_str, 8);
   }
 
   MTL::Size group_dims = MTL::Size(bn, 1, 1);
   MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
 }
 
 void multi_block_sort(
@@ -152,22 +152,21 @@ void multi_block_sort(
           << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
     auto kernel =
         get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
-    compute_encoder->setComputePipelineState(kernel);
+    compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(in, 0);
     compute_encoder.set_output_array(dev_vals_0, 1);
     compute_encoder.set_output_array(dev_idxs_0, 2);
-    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
-    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
-    compute_encoder->setBytes(
-        nc_shape.data(), nc_shape.size() * sizeof(int), 6);
-    compute_encoder->setBytes(nc_str.data(), nc_str.size() * sizeof(size_t), 7);
+    compute_encoder.set_bytes(size_sorted_axis, 3);
+    compute_encoder.set_bytes(stride_sorted_axis, 4);
+    compute_encoder.set_bytes(nc_dim, 5);
+    compute_encoder.set_vector_bytes(nc_shape, 6);
+    compute_encoder.set_vector_bytes(nc_str, 7);
 
     MTL::Size group_dims = MTL::Size(bn, 1, 1);
     MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
   }
 
   // Do merges
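One visible effect in the hunk above: binding a std::vector no longer requires spelling out data() and the byte count, so the setBytes call that previously wrapped onto a second line collapses to a single set_vector_bytes line, which is why this hunk shrinks from 22 lines to 21. Schematically, assuming nc_shape is a std::vector<int>:

// Before: pointer and byte count written out at the call site.
compute_encoder->setBytes(
    nc_shape.data(), nc_shape.size() * sizeof(int), 6);

// After: the wrapper derives both from the container.
compute_encoder.set_vector_bytes(nc_shape, 6);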
@@ -194,19 +193,19 @@ void multi_block_sort(
 
       auto kernel =
           get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);
 
       compute_encoder.set_output_array(block_partitions, 0);
       compute_encoder.set_input_array(dev_vals_in, 1);
       compute_encoder.set_input_array(dev_idxs_in, 2);
-      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
-      compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
-      compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
+      compute_encoder.set_bytes(size_sorted_axis, 3);
+      compute_encoder.set_bytes(merge_tiles, 4);
+      compute_encoder.set_bytes(n_blocks, 5);
 
       MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
       MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
 
-      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
     }
 
     // Do merge
@@ -217,21 +216,21 @@ void multi_block_sort(
 
       auto kernel =
           get_mb_sort_kernel(d, kname.str(), dev_vals_0, dev_idxs_0, bn, tn);
-      compute_encoder->setComputePipelineState(kernel);
+      compute_encoder.set_compute_pipeline_state(kernel);
 
       compute_encoder.set_input_array(block_partitions, 0);
       compute_encoder.set_input_array(dev_vals_in, 1);
       compute_encoder.set_input_array(dev_idxs_in, 2);
       compute_encoder.set_output_array(dev_vals_out, 3);
       compute_encoder.set_output_array(dev_idxs_out, 4);
-      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
-      compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
-      compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
+      compute_encoder.set_bytes(size_sorted_axis, 5);
+      compute_encoder.set_bytes(merge_tiles, 6);
+      compute_encoder.set_bytes(n_blocks, 7);
 
       MTL::Size group_dims = MTL::Size(bn, 1, 1);
       MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
 
-      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+      compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
     }
   }
 
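The commit message also mentions fixing extensions: out-of-tree Metal kernels that encoded work through the old member functions presumably need the same mechanical translation. A hypothetical extension launch after this change would follow the pattern repeated in every hunk above (kernel, in, out, axis_size, and strides are placeholder names, not identifiers from this diff):

// Hypothetical extension kernel launch using the wrapped encoder API.
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);       // scalar, passed by value
compute_encoder.set_vector_bytes(strides, 3);  // std::vector, sized internally
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);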