Ensure shape dimensions are within supported integer range (#566) (#704)

* Ensure shape dimensions are within supported integer range (#566)

* fix build

* fix rebase bug

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jack Mousseau 2024-03-25 13:29:45 -07:00 committed by GitHub
parent 479051ce1c
commit 8e686764ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 4 deletions

View File

@ -68,6 +68,22 @@ std::vector<int> broadcast_shapes(
bool is_same_shape(const std::vector<array>& arrays);
/**
 * Returns `dim` narrowed to `int` when it lies inside the range representable
 * by `int`.
 *
 * The `int` limits are widened to `long long` (signed `T`) or
 * `unsigned long long` (unsigned `T`) before comparing, so the check relies
 * only on built-in types rather than on the POSIX-only `ssize_t` typedef.
 *
 * @param dim Candidate shape dimension.
 * @return `dim` converted to `int`.
 * @throws std::invalid_argument if `dim` is below the smallest `int` (only
 *     possible for signed `T`) or above the largest `int`.
 */
template <typename T>
int check_shape_dim(const T dim) {
  constexpr bool is_signed = std::numeric_limits<T>::is_signed;
  using Wide = std::conditional_t<is_signed, long long, unsigned long long>;
  constexpr Wide lower = static_cast<Wide>(std::numeric_limits<int>::min());
  constexpr Wide upper = static_cast<Wide>(std::numeric_limits<int>::max());
  // For unsigned `T`, `lower` holds the wrapped-around value of the smallest
  // `int`, so the lower-bound comparison must be skipped in that case.
  if ((is_signed && dim < lower) || dim > upper) {
    throw std::invalid_argument(
        "Shape dimension falls outside supported `int` range.");
  }
  return static_cast<int>(dim);
}
/**
* Returns the axis normalized to be in the range [0, ndim).
* Based on numpy's normalize_axis_index. See

View File

@ -196,7 +196,7 @@ PyScalarT validate_shape(
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
shape.push_back(nb::len(list));
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
if (nb::isinstance<nb::list>(*l)) {
@ -205,7 +205,9 @@ void get_shape(T list, std::vector<int>& shape) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
}
return;
}
}

View File

@ -4,6 +4,8 @@
#include "python/src/convert.h"
#include "mlx/utils.h"
namespace nanobind {
template <>
struct ndarray_traits<float16_t> {
@ -43,7 +45,7 @@ array nd_array_to_mlx(
// Compute the shape and size
std::vector<int> shape;
for (int i = 0; i < nd_array.ndim(); i++) {
shape.push_back(nd_array.shape(i));
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
auto type = nd_array.dtype();

View File

@ -534,6 +534,14 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(b_npy.dtype, np_dtype)
# Checks that converting a NumPy array with an oversized dimension to an mlx
# array is rejected with a ValueError carrying the expected message.
def test_array_np_shape_dim_check(self):
# 2**31 elements: one more than the largest value of a 32-bit signed `int`
# (assumes `int` is 32 bits wide on the platform under test).
a_npy = np.empty(2**31, dtype=np.bool_)
with self.assertRaises(ValueError) as e:
mx.array(a_npy)
# The exception text must match the expected message exactly.
self.assertEqual(
str(e.exception), "Shape dimension falls outside supported `int` range."
)
def test_dtype_promotion(self):
dtypes_list = [
(mx.bool_, np.bool_),

View File

@ -60,3 +60,23 @@ TEST_CASE("test is same size and shape") {
CHECK_EQ(is_same_shape(tc.a), tc.expected);
}
}
TEST_CASE("test check shape dimension") {
  // Extremes of `int`; anything beyond them is expected to be rejected.
  const int lowest = std::numeric_limits<int>::min();
  const int highest = std::numeric_limits<int>::max();

  // The same extremes held in the wider signed and unsigned types.
  const ssize_t signed_lowest = static_cast<ssize_t>(lowest);
  const ssize_t signed_highest = static_cast<ssize_t>(highest);
  const size_t unsigned_highest = static_cast<size_t>(highest);

  // Plain `int` arguments come back unchanged.
  CHECK_EQ(check_shape_dim(-4), -4);
  CHECK_EQ(check_shape_dim(0), 0);
  CHECK_EQ(check_shape_dim(12), 12);

  // Wide signed arguments sitting exactly on the `int` limits are accepted.
  CHECK_EQ(check_shape_dim(signed_lowest), lowest);
  CHECK_EQ(check_shape_dim(signed_highest), highest);

  // Unsigned arguments up to and including the largest `int` are accepted.
  CHECK_EQ(check_shape_dim(size_t{0}), 0);
  CHECK_EQ(check_shape_dim(unsigned_highest), highest);

  // One step past either limit must throw.
  CHECK_THROWS_AS(check_shape_dim(signed_lowest - 1), std::invalid_argument);
  CHECK_THROWS_AS(check_shape_dim(signed_highest + 1), std::invalid_argument);
  CHECK_THROWS_AS(
      check_shape_dim(unsigned_highest + 1), std::invalid_argument);
}