From 39dbd92df546488a48c36cd957d4cb41320e2b76 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 19 Aug 2025 01:13:20 -0700 Subject: [PATCH] Make threadgroup size less or equal to grid size --- mlx/backend/cuda/custom_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 2f8be9221..5ce9d3097 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -352,7 +352,7 @@ void CustomKernel::eval_gpu( // Make the grid const auto [tx, ty, tz] = threadgroup_; const auto [gx, gy, gz] = grid_; - dim3 block(tx, ty, tz); + dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); // Call the kernel