From 656ed7f7808266ae7923a010a6b1f5d166cf6256 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 25 Jun 2025 13:03:09 -0700 Subject: [PATCH] Fix get 2d grid dims (#2316) --- mlx/backend/common/utils.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 457ecb7f7..9766e5e0c 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common( } } } - if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { throw std::runtime_error("Unable to safely factor shape."); } if (grid_y > grid_x) { std::swap(grid_x, grid_y); } + if (divisor > 1) { + grid_x = ((grid_x + divisor - 1) / divisor) * divisor; + } return std::make_tuple( static_cast(grid_x), static_cast(grid_y), 1); }