mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	iinfo and scalar overflow detection (#2009)
This commit is contained in:
		| @@ -206,6 +206,30 @@ void init_array(nb::module_& m) { | ||||
|         return os.str(); | ||||
|       }); | ||||
|  | ||||
|   nb::class_<mx::iinfo>( | ||||
|       m, | ||||
|       "iinfo", | ||||
|       R"pbdoc( | ||||
|       Get information on integer types. | ||||
|       )pbdoc") | ||||
|       .def(nb::init<mx::Dtype>()) | ||||
|       .def_ro( | ||||
|           "min", | ||||
|           &mx::iinfo::min, | ||||
|           R"pbdoc(The smallest representable number.)pbdoc") | ||||
|       .def_ro( | ||||
|           "max", | ||||
|           &mx::iinfo::max, | ||||
|           R"pbdoc(The largest representable number.)pbdoc") | ||||
|       .def_ro("dtype", &mx::iinfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") | ||||
|       .def("__repr__", [](const mx::iinfo& i) { | ||||
|         std::ostringstream os; | ||||
|         os << "iinfo(" | ||||
|            << "min=" << i.min << ", max=" << i.max << ", dtype=" << i.dtype | ||||
|            << ")"; | ||||
|         return os.str(); | ||||
|       }); | ||||
|  | ||||
|   nb::class_<ArrayAt>( | ||||
|       m, | ||||
|       "ArrayAt", | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| #include "python/src/utils.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "mlx/utils.h" | ||||
| #include "python/src/convert.h" | ||||
|  | ||||
| mx::array to_array( | ||||
| @@ -16,6 +17,16 @@ mx::array to_array( | ||||
|         ? mx::int64 | ||||
|         : mx::int32; | ||||
|     auto out_t = dtype.value_or(default_type); | ||||
|     if (mx::issubdtype(out_t, mx::integer) && out_t.size() < 8) { | ||||
|       auto info = mx::iinfo(out_t); | ||||
|       if (val < info.min || val > static_cast<int64_t>(info.max)) { | ||||
|         std::ostringstream msg; | ||||
|         msg << "Converting " << val << " to " << out_t | ||||
|             << " would result in overflow."; | ||||
|         throw std::invalid_argument(msg.str()); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // bool_ is an exception and is always promoted | ||||
|     return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t); | ||||
|   } else if (auto pv = std::get_if<nb::float_>(&v); pv) { | ||||
|   | ||||
| @@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) | ||||
|         self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) | ||||
|  | ||||
|     def test_iinfo(self): | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.iinfo(mx.float32) | ||||
|  | ||||
|         self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min) | ||||
|         self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max) | ||||
|         self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32) | ||||
|  | ||||
|         self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min) | ||||
|         self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max) | ||||
|         self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8) | ||||
|  | ||||
|  | ||||
| class TestEquality(mlx_tests.MLXTestCase): | ||||
|     def test_array_eq_array(self): | ||||
| @@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         used = get_mem() | ||||
|         self.assertEqual(expected, used) | ||||
|  | ||||
|     def test_scalar_integer_conversion_overflow(self): | ||||
|         y = mx.array(2000000000, dtype=mx.int32) | ||||
|         x = 3000000000 | ||||
|         with self.assertRaises(ValueError): | ||||
|             y + x | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.add(y, x) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun