mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix warp underflow
This commit is contained in:
@@ -257,9 +257,7 @@ void row_reduce_simple(
|
||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
if (warps > 32) {
|
||||
warps = 32;
|
||||
}
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
@@ -306,9 +304,7 @@ void row_reduce_looped(
|
||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||
warps /= 4;
|
||||
if (warps > 32) {
|
||||
warps = 32;
|
||||
}
|
||||
warps = std::max(std::min(warps, 32), 1);
|
||||
int threads = warps * WARP_SIZE;
|
||||
dim3 block(threads, 1, 1);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user