Fully wrap the command encoder (#1572)

* fully wrap the command encoder

* use consistent style + fix extensions
This commit is contained in:
Awni Hannun
2024-11-08 11:50:21 -08:00
committed by GitHub
parent 59247c2b62
commit 9f0d5c12fc
27 changed files with 469 additions and 484 deletions

View File

@@ -44,12 +44,12 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
@@ -60,7 +60,7 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -122,12 +122,12 @@ void explicit_gemm_conv_group_ND_gpu(
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
@@ -138,7 +138,7 @@ void explicit_gemm_conv_group_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatch_threads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
@@ -237,7 +237,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -252,8 +252,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
wn,
n_channel_specialization,
small_filter);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -368,11 +368,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -506,7 +506,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions
int tile = 1 << swizzle_log;
@@ -523,17 +523,15 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2);
// Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5);
compute_encoder.set_bytes(conv_params, 3);
compute_encoder.set_bytes(gemm_params, 4);
compute_encoder.set_bytes(jump_params, 5);
compute_encoder->setBytes(
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
compute_encoder.set_vector_bytes(base_h, 6);
compute_encoder.set_vector_bytes(base_w, 7);
// Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -622,18 +620,18 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3);
compute_encoder.set_bytes(C_c, 2);
compute_encoder.set_bytes(O_c, 3);
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -650,18 +648,17 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -698,18 +695,17 @@ void winograd_conv_2D_gpu(
<< bc;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
compute_encoder.set_bytes(conv_params_updated, 2);
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}