From 8e686764ac93cc6d0f9591b9a20eae6a9d10b27c Mon Sep 17 00:00:00 2001 From: Jack Mousseau Date: Mon, 25 Mar 2024 13:29:45 -0700 Subject: [PATCH] Ensure shape dimensions are within supported integer range (#566) (#704) * Ensure shape dimensions are within supported integer range (#566) * fix build * fix rebase bug --------- Co-authored-by: Awni Hannun --- mlx/utils.h | 16 ++++++++++++++++ python/src/array.cpp | 6 ++++-- python/src/convert.cpp | 4 +++- python/tests/test_array.py | 8 ++++++++ tests/utils_tests.cpp | 22 +++++++++++++++++++++- 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/mlx/utils.h b/mlx/utils.h index 8a5f2ccd9..1fe7edc05 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -68,6 +68,22 @@ std::vector broadcast_shapes( bool is_same_shape(const std::vector& arrays); +/** Returns the shape dimension if it's within allowed range. */ +template +int check_shape_dim(const T dim) { + constexpr bool is_signed = std::numeric_limits::is_signed; + using U = std::conditional_t; + constexpr U min = static_cast(std::numeric_limits::min()); + constexpr U max = static_cast(std::numeric_limits::max()); + + if ((is_signed && dim < min) || dim > max) { + throw std::invalid_argument( + "Shape dimension falls outside supported `int` range."); + } + + return static_cast(dim); +} + /** * Returns the axis normalized to be in the range [0, ndim). * Based on numpy's normalize_axis_index. See diff --git a/python/src/array.cpp b/python/src/array.cpp index d8a04d175..de57b701e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -196,7 +196,7 @@ PyScalarT validate_shape( template void get_shape(T list, std::vector& shape) { - shape.push_back(nb::len(list)); + shape.push_back(check_shape_dim(nb::len(list))); if (shape.back() > 0) { auto l = list.begin(); if (nb::isinstance(*l)) { @@ -205,7 +205,9 @@ void get_shape(T list, std::vector& shape) { return get_shape(nb::cast(*l), shape); } else if (nb::isinstance(*l)) { auto arr = nb::cast(*l); - shape.insert(shape.end(), arr.shape().begin(), arr.shape().end()); + for (int i = 0; i < arr.ndim(); i++) { + shape.push_back(check_shape_dim(arr.shape(i))); + } return; } } diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 740dbb664..a27e6313e 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -4,6 +4,8 @@ #include "python/src/convert.h" +#include "mlx/utils.h" + namespace nanobind { template <> struct ndarray_traits { @@ -43,7 +45,7 @@ array nd_array_to_mlx( // Compute the shape and size std::vector shape; for (int i = 0; i < nd_array.ndim(); i++) { - shape.push_back(nd_array.shape(i)); + shape.push_back(check_shape_dim(nd_array.shape(i))); } auto type = nd_array.dtype(); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 9b06d3e57..03a341f9c 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -534,6 +534,14 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(b_npy.dtype, np_dtype) + def test_array_np_shape_dim_check(self): + a_npy = np.empty(2**31, dtype=np.bool_) + with self.assertRaises(ValueError) as e: + mx.array(a_npy) + self.assertEqual( + str(e.exception), "Shape dimension falls outside supported `int` range." + ) + def test_dtype_promotion(self): dtypes_list = [ (mx.bool_, np.bool_), diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index e7bb35d21..f97553713 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -59,4 +59,24 @@ TEST_CASE("test is same size and shape") { for (const auto& tc : testCases) { CHECK_EQ(is_same_shape(tc.a), tc.expected); } -} \ No newline at end of file +} + +TEST_CASE("test check shape dimension") { + int dim_min = std::numeric_limits::min(); + int dim_max = std::numeric_limits::max(); + CHECK_EQ(check_shape_dim(-4), -4); + CHECK_EQ(check_shape_dim(0), 0); + CHECK_EQ(check_shape_dim(12), 12); + CHECK_EQ(check_shape_dim(static_cast(dim_min)), dim_min); + CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); + CHECK_EQ(check_shape_dim(static_cast(0)), 0); + CHECK_EQ(check_shape_dim(static_cast(dim_max)), dim_max); + CHECK_THROWS_AS( + check_shape_dim(static_cast(dim_min) - 1), + std::invalid_argument); + CHECK_THROWS_AS( + check_shape_dim(static_cast(dim_max) + 1), + std::invalid_argument); + CHECK_THROWS_AS( + check_shape_dim(static_cast(dim_max) + 1), std::invalid_argument); +}