mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
iinfo and scalar overflow detection (#2009)
This commit is contained in:
parent
bc62932984
commit
5580b47291
@ -380,4 +380,43 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_iinfo_limits(int64_t& min, uint64_t& max) {
|
||||
min = std::numeric_limits<T>::min();
|
||||
max = std::numeric_limits<T>::max();
|
||||
}
|
||||
|
||||
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
|
||||
switch (dtype) {
|
||||
case int8:
|
||||
set_iinfo_limits<int8_t>(min, max);
|
||||
break;
|
||||
case uint8:
|
||||
set_iinfo_limits<uint8_t>(min, max);
|
||||
break;
|
||||
case int16:
|
||||
set_iinfo_limits<int16_t>(min, max);
|
||||
break;
|
||||
case uint16:
|
||||
set_iinfo_limits<uint16_t>(min, max);
|
||||
break;
|
||||
case int32:
|
||||
set_iinfo_limits<int32_t>(min, max);
|
||||
break;
|
||||
case uint32:
|
||||
set_iinfo_limits<uint32_t>(min, max);
|
||||
break;
|
||||
case int64:
|
||||
set_iinfo_limits<int64_t>(min, max);
|
||||
break;
|
||||
case uint64:
|
||||
set_iinfo_limits<uint64_t>(min, max);
|
||||
break;
|
||||
default:
|
||||
std::ostringstream msg;
|
||||
msg << "[iinfo] dtype " << dtype << " is not integral.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -67,6 +67,14 @@ struct finfo {
|
||||
double max;
|
||||
};
|
||||
|
||||
/** Holds information about integral types. */
|
||||
struct iinfo {
|
||||
explicit iinfo(Dtype dtype);
|
||||
Dtype dtype;
|
||||
int64_t min;
|
||||
uint64_t max;
|
||||
};
|
||||
|
||||
/** The type from promoting the arrays' types with one another. */
|
||||
inline Dtype result_type(const array& a, const array& b) {
|
||||
return promote_types(a.dtype(), b.dtype());
|
||||
|
@ -206,6 +206,30 @@ void init_array(nb::module_& m) {
|
||||
return os.str();
|
||||
});
|
||||
|
||||
nb::class_<mx::iinfo>(
|
||||
m,
|
||||
"iinfo",
|
||||
R"pbdoc(
|
||||
Get information on integer types.
|
||||
)pbdoc")
|
||||
.def(nb::init<mx::Dtype>())
|
||||
.def_ro(
|
||||
"min",
|
||||
&mx::iinfo::min,
|
||||
R"pbdoc(The smallest representable number.)pbdoc")
|
||||
.def_ro(
|
||||
"max",
|
||||
&mx::iinfo::max,
|
||||
R"pbdoc(The largest representable number.)pbdoc")
|
||||
.def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||
.def("__repr__", [](const mx::iinfo& i) {
|
||||
std::ostringstream os;
|
||||
os << "iinfo("
|
||||
<< "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype
|
||||
<< ")";
|
||||
return os.str();
|
||||
});
|
||||
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
"ArrayAt",
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "python/src/utils.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/convert.h"
|
||||
|
||||
mx::array to_array(
|
||||
@ -16,6 +17,16 @@ mx::array to_array(
|
||||
? mx::int64
|
||||
: mx::int32;
|
||||
auto out_t = dtype.value_or(default_type);
|
||||
if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) {
|
||||
auto info = mx::iinfo(out_t);
|
||||
if (val < info.min || val > static_cast<int64_t>(info.max)) {
|
||||
std::ostringstream msg;
|
||||
msg << "Converting " << val << " to " << out_t
|
||||
<< " would result in overflow.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// bool_ is an exception and is always promoted
|
||||
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
|
@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||
|
||||
def test_iinfo(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.iinfo(mx.float32)
|
||||
|
||||
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)
|
||||
self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)
|
||||
self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)
|
||||
|
||||
self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)
|
||||
self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)
|
||||
self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)
|
||||
|
||||
|
||||
class TestEquality(mlx_tests.MLXTestCase):
|
||||
def test_array_eq_array(self):
|
||||
@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
used = get_mem()
|
||||
self.assertEqual(expected, used)
|
||||
|
||||
def test_scalar_integer_conversion_overflow(self):
|
||||
y = mx.array(2000000000, dtype=mx.int32)
|
||||
x = 3000000000
|
||||
with self.assertRaises(ValueError):
|
||||
y + x
|
||||
with self.assertRaises(ValueError):
|
||||
mx.add(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -55,3 +55,13 @@ TEST_CASE("test finfo") {
|
||||
CHECK_EQ(finfo(float16).min, -65504);
|
||||
CHECK_EQ(finfo(float16).max, 65504);
|
||||
}
|
||||
|
||||
TEST_CASE("test iinfo") {
|
||||
CHECK_EQ(iinfo(int8).dtype, int8);
|
||||
CHECK_EQ(iinfo(int64).dtype, int64);
|
||||
CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());
|
||||
CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
|
||||
CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
|
||||
CHECK_EQ(iinfo(uint64).min, 0);
|
||||
CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user