Ensure shape dimensions are within supported integer range (#566) (#704)

* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau
2024-03-25 13:29:45 -07:00
committed by GitHub
parent 479051ce1c
commit 8e686764ac
5 changed files with 52 additions and 4 deletions

View File

@@ -68,6 +68,22 @@ std::vector<int> broadcast_shapes(
bool is_same_shape(const std::vector<array>& arrays);
/** Returns the shape dimension if it's within allowed range. */
template <typename T>
int check_shape_dim(const T dim) {
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
using U = std::conditional_t<is_signed, ssize_t, size_t>;
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());
if ((is_signed && dim < min) || dim > max) {
throw std::invalid_argument(
"Shape dimension falls outside supported `int` range.");
}
return static_cast<int>(dim);
}
/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See