mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
Make threadgroup size less or equal to grid size
This commit is contained in:
parent
432c02dabc
commit
39dbd92df5
@ -352,7 +352,7 @@ void CustomKernel::eval_gpu(
|
||||
// Make the grid
|
||||
const auto [tx, ty, tz] = threadgroup_;
|
||||
const auto [gx, gy, gz] = grid_;
|
||||
dim3 block(tx, ty, tz);
|
||||
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
|
||||
|
||||
// Call the kernel
|
||||
|
Loading…
Reference in New Issue
Block a user