iinfo and scalar overflow detection (#2009)

This commit is contained in:
Awni Hannun
2025-03-27 19:54:56 -07:00
committed by GitHub
parent bc62932984
commit 5580b47291
6 changed files with 112 additions and 0 deletions

View File

@@ -206,6 +206,30 @@ void init_array(nb::module_& m) {
return os.str();
});
nb::class_<mx::iinfo>(
m,
"iinfo",
R"pbdoc(
Get information on integer types.
)pbdoc")
.def(nb::init<mx::Dtype>())
.def_ro(
"min",
&mx::iinfo::min,
R"pbdoc(The smallest representable number.)pbdoc")
.def_ro(
"max",
&mx::iinfo::max,
R"pbdoc(The largest representable number.)pbdoc")
.def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
.def("__repr__", [](const mx::iinfo& i) {
std::ostringstream os;
os << "iinfo("
<< "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype
<< ")";
return os.str();
});
nb::class_<ArrayAt>(
m,
"ArrayAt",

View File

@@ -2,6 +2,7 @@
#include "python/src/utils.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/convert.h"
mx::array to_array(
@@ -16,6 +17,16 @@ mx::array to_array(
? mx::int64
: mx::int32;
auto out_t = dtype.value_or(default_type);
if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) {
auto info = mx::iinfo(out_t);
if (val < info.min || val > static_cast<int64_t>(info.max)) {
std::ostringstream msg;
msg << "Converting " << val << " to " << out_t
<< " would result in overflow.";
throw std::invalid_argument(msg.str());
}
}
// bool_ is an exception and is always promoted
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {

View File

@@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
def test_iinfo(self):
with self.assertRaises(ValueError):
mx.iinfo(mx.float32)
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)
self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)
self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)
self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)
self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)
self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)
class TestEquality(mlx_tests.MLXTestCase):
def test_array_eq_array(self):
@@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase):
used = get_mem()
self.assertEqual(expected, used)
def test_scalar_integer_conversion_overflow(self):
y = mx.array(2000000000, dtype=mx.int32)
x = 3000000000
with self.assertRaises(ValueError):
y + x
with self.assertRaises(ValueError):
mx.add(y, x)
if __name__ == "__main__":
unittest.main()