Make threadgroup size less or equal to grid size

This commit is contained in:
Angelos Katharopoulos 2025-08-19 01:13:20 -07:00
parent 432c02dabc
commit 39dbd92df5

View File

@ -352,7 +352,7 @@ void CustomKernel::eval_gpu(
// Make the grid
const auto [tx, ty, tz] = threadgroup_;
const auto [gx, gy, gz] = grid_;
dim3 block(tx, ty, tz);
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
// Call the kernel