From 6931f84412d867770ca469c46866d4f43fdfbcf7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Nov 2024 08:35:25 -0800 Subject: [PATCH] fix dispatch threads for a few kernels (#1594) --- mlx/backend/metal/conv.cpp | 14 ++++++++------ mlx/backend/metal/custom_kernel.cpp | 3 ++- mlx/backend/metal/matmul.cpp | 6 ++---- mlx/backend/metal/primitives.cpp | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index d5e715e80..fc1649730 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -52,13 +52,14 @@ void explicit_gemm_conv_ND_gpu( compute_encoder.set_bytes(conv_params, 2); // Launch unfolding kernel - int tgp_x = std::min(conv_params.C, 64); + size_t tgp_x = std::min(conv_params.C, 64); tgp_x = 32 * ((tgp_x + 32 - 1) / 32); - int tgp_y = 256 / tgp_x; + size_t tgp_y = 256 / tgp_x; - MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1); MTL::Size grid_dims = MTL::Size( conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); + MTL::Size group_dims = MTL::Size( + std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1); compute_encoder.dispatch_threads(grid_dims, group_dims); @@ -130,13 +131,14 @@ void explicit_gemm_conv_group_ND_gpu( compute_encoder.set_bytes(conv_params, 2); // Launch unfolding kernel - int tgp_x = std::min(conv_params.C, 64); + size_t tgp_x = std::min(conv_params.C, 64); tgp_x = 32 * ((tgp_x + 32 - 1) / 32); - int tgp_y = 256 / tgp_x; + size_t tgp_y = 256 / tgp_x; - MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1); MTL::Size grid_dims = MTL::Size( conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); + MTL::Size group_dims = MTL::Size( + std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1); compute_encoder.dispatch_threads(grid_dims, group_dims); diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 8e0fb1173..66f75a65f 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -72,8 +72,9 @@ void CustomKernel::eval_gpu( } const auto [tx, ty, tz] = threadgroup_; - MTL::Size group_dims = MTL::Size(tx, ty, tz); const auto [gx, gy, gz] = grid_; + MTL::Size group_dims = + MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); MTL::Size grid_dims = MTL::Size(gx, gy, gz); compute_encoder.dispatch_threads(grid_dims, group_dims); diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index af3e85ec8..7d2ccd87f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -437,8 +437,7 @@ void steel_matmul( // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); - MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); - + auto group_dims = get_block_dims(N, M, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } @@ -955,8 +954,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { // Launch enough thread groups for each output MTL::Size grid_dims = MTL::Size(N, M, 1); - MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); - + auto group_dims = get_block_dims(N, M, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index da176ffd1..732b40edf 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { // organize into grid nkeys x elem_per_key MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - MTL::Size group_dims = MTL::Size(1, thread_group_size, 1); + auto group_dims = get_block_dims(num_keys, half_size + odd, 1); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(keys, 0);