fix int64 bug (#1860)

This commit is contained in:
Alex Barron
2025-02-12 19:23:46 -08:00
committed by GitHub
parent 0145911bea
commit 55c5ac7820
3 changed files with 25 additions and 4 deletions

View File

@@ -352,6 +352,18 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(0.0)
self.assertFalse(x)
def test_int_type(self):
x = mx.array(1)
self.assertTrue(x.dtype == mx.int32)
x = mx.array(2**32 - 1)
self.assertTrue(x.dtype == mx.int64)
x = mx.array(2**40)
self.assertTrue(x.dtype == mx.int64)
x = mx.array(2**32 - 1, dtype=mx.uint32)
self.assertTrue(x.dtype == mx.uint32)
x = mx.array([1, 2], dtype=mx.int64) + 0x80000000
self.assertTrue(x.dtype == mx.int64)
def test_construction_from_lists(self):
x = mx.array([])
self.assertEqual(x.size, 0)