mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
* Ensure shape dimensions are within supported integer range (#566) * fix build * fix rebase bug --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
479051ce1c
commit
8e686764ac
16
mlx/utils.h
16
mlx/utils.h
@ -68,6 +68,22 @@ std::vector<int> broadcast_shapes(
|
|||||||
|
|
||||||
bool is_same_shape(const std::vector<array>& arrays);
|
bool is_same_shape(const std::vector<array>& arrays);
|
||||||
|
|
||||||
|
/** Returns the shape dimension if it's within allowed range. */
|
||||||
|
template <typename T>
|
||||||
|
int check_shape_dim(const T dim) {
|
||||||
|
constexpr bool is_signed = std::numeric_limits<T>::is_signed;
|
||||||
|
using U = std::conditional_t<is_signed, ssize_t, size_t>;
|
||||||
|
constexpr U min = static_cast<U>(std::numeric_limits<int>::min());
|
||||||
|
constexpr U max = static_cast<U>(std::numeric_limits<int>::max());
|
||||||
|
|
||||||
|
if ((is_signed && dim < min) || dim > max) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Shape dimension falls outside supported `int` range.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return static_cast<int>(dim);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the axis normalized to be in the range [0, ndim).
|
* Returns the axis normalized to be in the range [0, ndim).
|
||||||
* Based on numpy's normalize_axis_index. See
|
* Based on numpy's normalize_axis_index. See
|
||||||
|
@ -196,7 +196,7 @@ PyScalarT validate_shape(
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void get_shape(T list, std::vector<int>& shape) {
|
void get_shape(T list, std::vector<int>& shape) {
|
||||||
shape.push_back(nb::len(list));
|
shape.push_back(check_shape_dim(nb::len(list)));
|
||||||
if (shape.back() > 0) {
|
if (shape.back() > 0) {
|
||||||
auto l = list.begin();
|
auto l = list.begin();
|
||||||
if (nb::isinstance<nb::list>(*l)) {
|
if (nb::isinstance<nb::list>(*l)) {
|
||||||
@ -205,7 +205,9 @@ void get_shape(T list, std::vector<int>& shape) {
|
|||||||
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
||||||
} else if (nb::isinstance<array>(*l)) {
|
} else if (nb::isinstance<array>(*l)) {
|
||||||
auto arr = nb::cast<array>(*l);
|
auto arr = nb::cast<array>(*l);
|
||||||
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
|
for (int i = 0; i < arr.ndim(); i++) {
|
||||||
|
shape.push_back(check_shape_dim(arr.shape(i)));
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
#include "python/src/convert.h"
|
#include "python/src/convert.h"
|
||||||
|
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace nanobind {
|
namespace nanobind {
|
||||||
template <>
|
template <>
|
||||||
struct ndarray_traits<float16_t> {
|
struct ndarray_traits<float16_t> {
|
||||||
@ -43,7 +45,7 @@ array nd_array_to_mlx(
|
|||||||
// Compute the shape and size
|
// Compute the shape and size
|
||||||
std::vector<int> shape;
|
std::vector<int> shape;
|
||||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||||
shape.push_back(nd_array.shape(i));
|
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||||
}
|
}
|
||||||
auto type = nd_array.dtype();
|
auto type = nd_array.dtype();
|
||||||
|
|
||||||
|
@ -534,6 +534,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(b_npy.dtype, np_dtype)
|
self.assertEqual(b_npy.dtype, np_dtype)
|
||||||
|
|
||||||
|
def test_array_np_shape_dim_check(self):
|
||||||
|
a_npy = np.empty(2**31, dtype=np.bool_)
|
||||||
|
with self.assertRaises(ValueError) as e:
|
||||||
|
mx.array(a_npy)
|
||||||
|
self.assertEqual(
|
||||||
|
str(e.exception), "Shape dimension falls outside supported `int` range."
|
||||||
|
)
|
||||||
|
|
||||||
def test_dtype_promotion(self):
|
def test_dtype_promotion(self):
|
||||||
dtypes_list = [
|
dtypes_list = [
|
||||||
(mx.bool_, np.bool_),
|
(mx.bool_, np.bool_),
|
||||||
|
@ -60,3 +60,23 @@ TEST_CASE("test is same size and shape") {
|
|||||||
CHECK_EQ(is_same_shape(tc.a), tc.expected);
|
CHECK_EQ(is_same_shape(tc.a), tc.expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test check shape dimension") {
|
||||||
|
int dim_min = std::numeric_limits<int>::min();
|
||||||
|
int dim_max = std::numeric_limits<int>::max();
|
||||||
|
CHECK_EQ(check_shape_dim(-4), -4);
|
||||||
|
CHECK_EQ(check_shape_dim(0), 0);
|
||||||
|
CHECK_EQ(check_shape_dim(12), 12);
|
||||||
|
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_min)), dim_min);
|
||||||
|
CHECK_EQ(check_shape_dim(static_cast<ssize_t>(dim_max)), dim_max);
|
||||||
|
CHECK_EQ(check_shape_dim(static_cast<size_t>(0)), 0);
|
||||||
|
CHECK_EQ(check_shape_dim(static_cast<size_t>(dim_max)), dim_max);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
check_shape_dim(static_cast<ssize_t>(dim_min) - 1),
|
||||||
|
std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
check_shape_dim(static_cast<ssize_t>(dim_max) + 1),
|
||||||
|
std::invalid_argument);
|
||||||
|
CHECK_THROWS_AS(
|
||||||
|
check_shape_dim(static_cast<size_t>(dim_max) + 1), std::invalid_argument);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user