Ensure shape dimensions are within supported integer range (#566) (#704)

* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau
2024-03-25 13:29:45 -07:00
committed by GitHub
parent 479051ce1c
commit 8e686764ac
5 changed files with 52 additions and 4 deletions

View File

@@ -59,4 +59,24 @@ TEST_CASE("test is same size and shape") {
for (const auto& tc : testCases) {
CHECK_EQ(is_same_shape(tc.a), tc.expected);
}
}
}
TEST_CASE("test check shape dimension") {
int dim_min = std::numeric_limits<int>::min();
int dim_max = std::numeric_limits<int>::max();
CHECK_EQ(check_shape_dim(-4), -4);
CHECK_EQ(check_shape_dim(0), 0);
CHECK_EQ(check_shape_dim(12), 12);
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_min)), dim_min);
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_max)), dim_max);
CHECK_EQ(check_shape_dim(static_cast<size_t>(0)), 0);
CHECK_EQ(check_shape_dim(static_cast<size_t>(dim_max)), dim_max);
CHECK_THROWS_AS(
check_shape_dim(static_cast<ssize_t>(dim_min) - 1),
std::invalid_argument);
CHECK_THROWS_AS(
check_shape_dim(static_cast<ssize_t>(dim_max) + 1),
std::invalid_argument);
CHECK_THROWS_AS(
check_shape_dim(static_cast<size_t>(dim_max) + 1), std::invalid_argument);
}