From 5580b47291caed127e54a427a265e7776add2997 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 27 Mar 2025 19:54:56 -0700 Subject: [PATCH] iinfo and scalar overflow detection (#2009) --- mlx/utils.cpp | 39 ++++++++++++++++++++++++++++++++++++++ mlx/utils.h | 8 ++++++++ python/src/array.cpp | 24 +++++++++++++++++++++++ python/src/utils.cpp | 11 +++++++++++ python/tests/test_array.py | 20 +++++++++++++++++++ tests/utils_tests.cpp | 10 ++++++++++ 6 files changed, 112 insertions(+) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 9168b34c8..5197e516f 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { } } +template +void set_iinfo_limits(int64_t& min, uint64_t& max) { + min = std::numeric_limits::min(); + max = std::numeric_limits::max(); +} + +iinfo::iinfo(Dtype dtype) : dtype(dtype) { + switch (dtype) { + case int8: + set_iinfo_limits(min, max); + break; + case uint8: + set_iinfo_limits(min, max); + break; + case int16: + set_iinfo_limits(min, max); + break; + case uint16: + set_iinfo_limits(min, max); + break; + case int32: + set_iinfo_limits(min, max); + break; + case uint32: + set_iinfo_limits(min, max); + break; + case int64: + set_iinfo_limits(min, max); + break; + case uint64: + set_iinfo_limits(min, max); + break; + default: + std::ostringstream msg; + msg << "[iinfo] dtype " << dtype << " is not integral."; + throw std::invalid_argument(msg.str()); + } +} + } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index 0b5ce54a1..19241e4c6 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -67,6 +67,14 @@ struct finfo { double max; }; +/** Holds information about integral types. */ +struct iinfo { + explicit iinfo(Dtype dtype); + Dtype dtype; + int64_t min; + uint64_t max; +}; + /** The type from promoting the arrays' types with one another. */ inline Dtype result_type(const array& a, const array& b) { return promote_types(a.dtype(), b.dtype()); diff --git a/python/src/array.cpp b/python/src/array.cpp index 375f9a3ec..e380f2652 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -206,6 +206,30 @@ void init_array(nb::module_& m) { return os.str(); }); + nb::class_( + m, + "iinfo", + R"pbdoc( + Get information on integer types. + )pbdoc") + .def(nb::init()) + .def_ro( + "min", + &mx::iinfo::min, + R"pbdoc(The smallest representable number.)pbdoc") + .def_ro( + "max", + &mx::iinfo::max, + R"pbdoc(The largest representable number.)pbdoc") + .def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") + .def("__repr__", [](const mx::iinfo& i) { + std::ostringstream os; + os << "iinfo(" + << "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype + << ")"; + return os.str(); + }); + nb::class_( m, "ArrayAt", diff --git a/python/src/utils.cpp b/python/src/utils.cpp index e6ca346dc..08f78bdf4 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -2,6 +2,7 @@ #include "python/src/utils.h" #include "mlx/ops.h" +#include "mlx/utils.h" #include "python/src/convert.h" mx::array to_array( @@ -16,6 +17,16 @@ mx::array to_array( ? mx::int64 : mx::int32; auto out_t = dtype.value_or(default_type); + if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) { + auto info = mx::iinfo(out_t); + if (val < info.min || val > static_cast(info.max)) { + std::ostringstream msg; + msg << "Converting " << val << " to " << out_t + << " would result in overflow."; + throw std::invalid_argument(msg.str()); + } + } + // bool_ is an exception and is always promoted return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index f22d7ced0..67dc7c84b 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) + def test_iinfo(self): + with self.assertRaises(ValueError): + mx.iinfo(mx.float32) + + self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min) + self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max) + self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32) + + self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min) + self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max) + self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8) + class TestEquality(mlx_tests.MLXTestCase): def test_array_eq_array(self): @@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase): used = get_mem() self.assertEqual(expected, used) + def test_scalar_integer_conversion_overflow(self): + y = mx.array(2000000000, dtype=mx.int32) + x = 3000000000 + with self.assertRaises(ValueError): + y + x + with self.assertRaises(ValueError): + mx.add(y, x) + if __name__ == "__main__": unittest.main() diff --git a/tests/utils_tests.cpp b/tests/utils_tests.cpp index a17f12e33..88c3e7b37 100644 --- a/tests/utils_tests.cpp +++ b/tests/utils_tests.cpp @@ -55,3 +55,13 @@ TEST_CASE("test finfo") { CHECK_EQ(finfo(float16).min, -65504); CHECK_EQ(finfo(float16).max, 65504); } + +TEST_CASE("test iinfo") { + CHECK_EQ(iinfo(int8).dtype, int8); + CHECK_EQ(iinfo(int64).dtype, int64); + CHECK_EQ(iinfo(int64).max, std::numeric_limits::max()); + CHECK_EQ(iinfo(uint64).max, std::numeric_limits::max()); + CHECK_EQ(iinfo(uint64).max, std::numeric_limits::max()); + CHECK_EQ(iinfo(uint64).min, 0); + CHECK_EQ(iinfo(int64).min, std::numeric_limits::min()); +}