Ensure shape dimensions are within supported integer range (#566) (#704)

* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau
2024-03-25 13:29:45 -07:00
committed by GitHub
parent 479051ce1c
commit 8e686764ac
5 changed files with 52 additions and 4 deletions

View File

@@ -196,7 +196,7 @@ PyScalarT validate_shape(
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
shape.push_back(nb::len(list));
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
if (nb::isinstance<nb::list>(*l)) {
@@ -205,7 +205,9 @@ void get_shape(T list, std::vector<int>& shape) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
}
return;
}
}