mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix warp underflow
This commit is contained in:
@@ -257,9 +257,7 @@ void row_reduce_simple(
|
|||||||
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
|
||||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
warps /= 4;
|
warps /= 4;
|
||||||
if (warps > 32) {
|
warps = std::max(std::min(warps, 32), 1);
|
||||||
warps = 32;
|
|
||||||
}
|
|
||||||
int threads = warps * WARP_SIZE;
|
int threads = warps * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
@@ -306,9 +304,7 @@ void row_reduce_looped(
|
|||||||
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
size_t reductions = (args.row_size + N_READS - 1) / N_READS;
|
||||||
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
warps /= 4;
|
warps /= 4;
|
||||||
if (warps > 32) {
|
warps = std::max(std::min(warps, 32), 1);
|
||||||
warps = 32;
|
|
||||||
}
|
|
||||||
int threads = warps * WARP_SIZE;
|
int threads = warps * WARP_SIZE;
|
||||||
dim3 block(threads, 1, 1);
|
dim3 block(threads, 1, 1);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user