diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 2f8be9221..5ce9d3097 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -352,7 +352,7 @@ void CustomKernel::eval_gpu( // Make the grid const auto [tx, ty, tz] = threadgroup_; const auto [gx, gy, gz] = grid_; - dim3 block(tx, ty, tz); + dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); // Call the kernel