fix dispatch threads for a few kernels (#1594)

This commit is contained in:
Awni Hannun 2024-11-18 08:35:25 -08:00 committed by GitHub
parent 16ec0556a0
commit 6931f84412
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 13 additions and 12 deletions

View File

@ -52,13 +52,14 @@ void explicit_gemm_conv_ND_gpu(
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
@ -130,13 +131,14 @@ void explicit_gemm_conv_group_ND_gpu(
compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);

View File

@ -72,8 +72,9 @@ void CustomKernel::eval_gpu(
}
const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder.dispatch_threads(grid_dims, group_dims);

View File

@ -437,8 +437,7 @@ void steel_matmul(
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
@ -955,8 +954,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch enough thread groups for each output
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
auto group_dims = get_block_dims(N, M, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}

View File

@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto group_dims = get_block_dims(num_keys, half_size + odd, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(keys, 0);