Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
Explicit barriers with concurrent dispatch (#977)
@@ -41,12 +41,12 @@ void explicit_gemm_conv_ND_gpu(
   // Prepare unfolding kernel
   std::ostringstream kname;
   kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, in_unfolded, 1);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_output_array(in_unfolded, 1);

   compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
@@ -140,7 +140,7 @@ void slow_conv_2D_gpu(
         << "_tm" << tm << "_tn" << tn;

   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

@@ -153,9 +153,9 @@ void slow_conv_2D_gpu(
   MTL::Size group_dims = MTL::Size(bm, bn, 1);
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);

   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -241,7 +241,7 @@ void implicit_gemm_conv_2D_gpu(
         << "_filter_" << (small_filter ? 's' : 'l');

   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

@@ -254,9 +254,9 @@ void implicit_gemm_conv_2D_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);

   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);

   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -394,7 +394,7 @@ void implicit_gemm_conv_2D_general_gpu(
         << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn;

   // Encode and dispatch kernel
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

@@ -408,9 +408,9 @@ void implicit_gemm_conv_2D_general_gpu(
   MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);

   // Encode arrays
-  set_array_buffer(compute_encoder, in, 0);
-  set_array_buffer(compute_encoder, wt, 1);
-  set_array_buffer(compute_encoder, out, 2);
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);

   // Encode params
   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
@@ -511,12 +511,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

-  set_array_buffer(compute_encoder, wt, 0);
-  set_array_buffer(compute_encoder, filt_wg, 1);
+  compute_encoder.set_input_array(wt, 0);
+  compute_encoder.set_output_array(filt_wg, 1);

   compute_encoder->setBytes(&C_c, sizeof(int), 2);
   compute_encoder->setBytes(&O_c, sizeof(int), 3);
@@ -539,12 +539,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

-  set_array_buffer(compute_encoder, in_padded, 0);
-  set_array_buffer(compute_encoder, inp_wg, 1);
+  compute_encoder.set_input_array(in_padded, 0);
+  compute_encoder.set_output_array(inp_wg, 1);

   compute_encoder->setBytes(
       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
@@ -587,12 +587,12 @@ void winograd_conv_2D_gpu(
   std::ostringstream kname;
   kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
         << bc;
-  auto compute_encoder = d.get_command_encoder(s.index);
+  auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);

-  set_array_buffer(compute_encoder, out_wg, 0);
-  set_array_buffer(compute_encoder, out, 1);
+  compute_encoder.set_input_array(out_wg, 0);
+  compute_encoder.set_output_array(out, 1);

   compute_encoder->setBytes(
       &conv_params_updated, sizeof(MLXConvParams<2>), 2);
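
Every hunk above applies the same two-part change: the command encoder is taken by reference rather than copied, and buffers are bound through the encoder's set_input_array / set_output_array methods instead of the free function set_array_buffer. Declaring reads and writes through the encoder is what lets it insert the explicit barriers between concurrently dispatched kernels referenced in the commit title. A minimal sketch of the post-change dispatch pattern, assembled from the slow_conv_2D_gpu hunks above (identifiers such as d, s, kname, in, wt, out, conv_params, grid_dims, and group_dims are assumed to be in scope; this is an illustration of the pattern, not code taken verbatim from a single hunk):

// Sketch only: the dispatch pattern the hunks above converge on.
auto& compute_encoder = d.get_command_encoder(s.index); // hold a reference, not a copy
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);

// Declare reads and writes through the encoder so it can track dependencies
// and emit barriers when kernels are dispatched concurrently.
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);

compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);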
||||