mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
* Ensure shape dimensions are within supported integer range (#566) * fix build * fix rebase bug --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -59,4 +59,24 @@ TEST_CASE("test is same size and shape") {
|
||||
for (const auto& tc : testCases) {
|
||||
CHECK_EQ(is_same_shape(tc.a), tc.expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test check shape dimension") {
|
||||
int dim_min = std::numeric_limits<int>::min();
|
||||
int dim_max = std::numeric_limits<int>::max();
|
||||
CHECK_EQ(check_shape_dim(-4), -4);
|
||||
CHECK_EQ(check_shape_dim(0), 0);
|
||||
CHECK_EQ(check_shape_dim(12), 12);
|
||||
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_min)), dim_min);
|
||||
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_max)), dim_max);
|
||||
CHECK_EQ(check_shape_dim(static_cast<size_t>(0)), 0);
|
||||
CHECK_EQ(check_shape_dim(static_cast<size_t>(dim_max)), dim_max);
|
||||
CHECK_THROWS_AS(
|
||||
check_shape_dim(static_cast<ssize_t>(dim_min) - 1),
|
||||
std::invalid_argument);
|
||||
CHECK_THROWS_AS(
|
||||
check_shape_dim(static_cast<ssize_t>(dim_max) + 1),
|
||||
std::invalid_argument);
|
||||
CHECK_THROWS_AS(
|
||||
check_shape_dim(static_cast<size_t>(dim_max) + 1), std::invalid_argument);
|
||||
}
|
||||
|
Reference in New Issue
Block a user