Ensure shape dimensions are within supported integer range (#566) (#704)

* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau
2024-03-25 13:29:45 -07:00
committed by GitHub
parent 479051ce1c
commit 8e686764ac
5 changed files with 52 additions and 4 deletions

View File

@@ -196,7 +196,7 @@ PyScalarT validate_shape(
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
shape.push_back(nb::len(list));
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
if (nb::isinstance<nb::list>(*l)) {
@@ -205,7 +205,9 @@ void get_shape(T list, std::vector<int>& shape) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
}
return;
}
}

View File

@@ -4,6 +4,8 @@
#include "python/src/convert.h"
#include "mlx/utils.h"
namespace nanobind {
template <>
struct ndarray_traits<float16_t> {
@@ -43,7 +45,7 @@ array nd_array_to_mlx(
// Compute the shape and size
std::vector<int> shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(nd_array.shape(i));
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
auto type = nd_array.dtype();

View File

@@ -534,6 +534,14 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(b_npy.dtype, np_dtype)
def test_array_np_shape_dim_check(self):
a_npy = np.empty(2**31, dtype=np.bool_)
with self.assertRaises(ValueError) as e:
mx.array(a_npy)
self.assertEqual(
str(e.exception), "Shape dimension falls outside supported `int` range."
)
def test_dtype_promotion(self):
dtypes_list = [
(mx.bool_, np.bool_),