mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 04:51:13 +08:00
fix dispatch threads for a few kernels (#1594)
This commit is contained in:
parent
16ec0556a0
commit
6931f84412
@@ -52,13 +52,14 @@ void explicit_gemm_conv_ND_gpu(
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
@@ -130,13 +131,14 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
compute_encoder.set_bytes(conv_params, 2);
|
||||
|
||||
// Launch unfolding kernel
|
||||
int tgp_x = std::min(conv_params.C, 64);
|
||||
size_t tgp_x = std::min(conv_params.C, 64);
|
||||
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
||||
int tgp_y = 256 / tgp_x;
|
||||
size_t tgp_y = 256 / tgp_x;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
||||
MTL::Size group_dims = MTL::Size(
|
||||
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
|
@@ -72,8 +72,9 @@ void CustomKernel::eval_gpu(
|
||||
}
|
||||
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
MTL::Size group_dims = MTL::Size(tx, ty, tz);
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
MTL::Size group_dims =
|
||||
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
|
||||
|
@@ -437,8 +437,7 @@ void steel_matmul(
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
|
||||
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
@@ -955,8 +954,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
MTL::Size grid_dims = MTL::Size(N, M, 1);
|
||||
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
|
||||
|
||||
auto group_dims = get_block_dims(N, M, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// organize into grid nkeys x elem_per_key
|
||||
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
|
||||
auto group_dims = get_block_dims(num_keys, half_size + odd, 1);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(keys, 0);
|
||||
|
Loading…
Reference in New Issue
Block a user