diff --git a/mlx/utils.h b/mlx/utils.h index eb194b71ec..5e0ef22222 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -70,7 +70,7 @@ bool is_same_shape(const std::vector& arrays); template int check_shape_dim(const T dim) { constexpr bool is_signed = std::numeric_limits::is_signed; - using U = std::conditional_t; + using U = std::conditional_t; constexpr U min = static_cast(std::numeric_limits::min()); constexpr U max = static_cast(std::numeric_limits::max()); diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index f975537139..3f2e3f814f 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -67,15 +67,15 @@ TEST_CASE("test check shape dimension") { CHECK_EQ(check_shape_dim(-4), -4); CHECK_EQ(check_shape_dim(0), 0); CHECK_EQ(check_shape_dim(12), 12); - CHECK_EQ(check_shape_dim(static_cast(dim_min)), dim_min); - CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); + CHECK_EQ(check_shape_dim(static_cast(dim_min)), dim_min); + CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); CHECK_EQ(check_shape_dim(static_cast(0)), 0); CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); CHECK_THROWS_AS( - check_shape_dim(static_cast(dim_min) - 1), + check_shape_dim(static_cast(dim_min) - 1), std::invalid_argument); CHECK_THROWS_AS( - check_shape_dim(static_cast(dim_max) + 1), + check_shape_dim(static_cast(dim_max) + 1), std::invalid_argument); CHECK_THROWS_AS( check_shape_dim(static_cast(dim_max) + 1), std::invalid_argument);