iinfo and scalar overflow detection (#2009)

This commit is contained in:
Awni Hannun
2025-03-27 19:54:56 -07:00
committed by GitHub
parent bc62932984
commit 5580b47291
6 changed files with 112 additions and 0 deletions

View File

@@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
def test_iinfo(self):
with self.assertRaises(ValueError):
mx.iinfo(mx.float32)
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)
self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)
self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)
self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)
self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)
self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)
class TestEquality(mlx_tests.MLXTestCase):
def test_array_eq_array(self):
@@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase):
used = get_mem()
self.assertEqual(expected, used)
def test_scalar_integer_conversion_overflow(self):
y = mx.array(2000000000, dtype=mx.int32)
x = 3000000000
with self.assertRaises(ValueError):
y + x
with self.assertRaises(ValueError):
mx.add(y, x)
if __name__ == "__main__":
unittest.main()