diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 457ecb7f7..9766e5e0c 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common( } } } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { throw std::runtime_error("Unable to safely factor shape."); } if (grid_y > grid_x) { std::swap(grid_x, grid_y); } + if (divisor > 1) { + grid_x = ((grid_x + divisor - 1) / divisor) * divisor; + } return std::make_tuple( static_cast(grid_x), static_cast(grid_y), 1); }