Explicit barriers with concurrent dispatch (#977)

This commit is contained in:
Awni Hannun
2024-04-10 21:45:31 -07:00
committed by GitHub
parent 8580d997ff
commit 12d4507ee3
21 changed files with 326 additions and 267 deletions

View File

@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, in_unfolded, 1);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
<< "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(bm, bn, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
<< "_filter_" << (small_filter ? 's' : 'l');
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
<< "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, wt, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, wt, 0);
set_array_buffer(compute_encoder, filt_wg, 1);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in_padded, 0);
set_array_buffer(compute_encoder, inp_wg, 1);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
std::ostringstream kname;
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
<< bc;
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, out_wg, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);