mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-29 07:03:10 +08:00
Make threadgroup size less or equal to grid size
This commit is contained in:
parent
432c02dabc
commit
39dbd92df5
@ -352,7 +352,7 @@ void CustomKernel::eval_gpu(
|
|||||||
// Make the grid
|
// Make the grid
|
||||||
const auto [tx, ty, tz] = threadgroup_;
|
const auto [tx, ty, tz] = threadgroup_;
|
||||||
const auto [gx, gy, gz] = grid_;
|
const auto [gx, gy, gz] = grid_;
|
||||||
dim3 block(tx, ty, tz);
|
dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
|
||||||
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
|
dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
|
||||||
|
|
||||||
// Call the kernel
|
// Call the kernel
|
||||||
|
Loading…
Reference in New Issue
Block a user