mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
iinfo and scalar overflow detection (#2009)
This commit is contained in:
@@ -109,6 +109,18 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||
|
||||
def test_iinfo(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.iinfo(mx.float32)
|
||||
|
||||
self.assertEqual(mx.iinfo(mx.int32).min, np.iinfo(np.int32).min)
|
||||
self.assertEqual(mx.iinfo(mx.int32).max, np.iinfo(np.int32).max)
|
||||
self.assertEqual(mx.iinfo(mx.int32).dtype, mx.int32)
|
||||
|
||||
self.assertEqual(mx.iinfo(mx.uint32).min, np.iinfo(np.uint32).min)
|
||||
self.assertEqual(mx.iinfo(mx.uint32).max, np.iinfo(np.uint32).max)
|
||||
self.assertEqual(mx.iinfo(mx.int8).dtype, mx.int8)
|
||||
|
||||
|
||||
class TestEquality(mlx_tests.MLXTestCase):
|
||||
def test_array_eq_array(self):
|
||||
@@ -1999,6 +2011,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
used = get_mem()
|
||||
self.assertEqual(expected, used)
|
||||
|
||||
def test_scalar_integer_conversion_overflow(self):
|
||||
y = mx.array(2000000000, dtype=mx.int32)
|
||||
x = 3000000000
|
||||
with self.assertRaises(ValueError):
|
||||
y + x
|
||||
with self.assertRaises(ValueError):
|
||||
mx.add(y, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user