Fix get 2d grid dims (#2316)

This commit is contained in:
Angelos Katharopoulos 2025-06-25 13:03:09 -07:00 committed by GitHub
parent 81bb9a2a9e
commit 656ed7f780
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -199,12 +199,15 @@ Dims get_2d_grid_dims_common(
} }
} }
} }
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) { if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape."); throw std::runtime_error("Unable to safely factor shape.");
} }
if (grid_y > grid_x) { if (grid_y > grid_x) {
std::swap(grid_x, grid_y); std::swap(grid_x, grid_y);
} }
if (divisor > 1) {
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
}
return std::make_tuple( return std::make_tuple(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1); static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
} }