From 55c5ac7820acf540dc430470f13dab261f270135 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 12 Feb 2025 19:23:46 -0800 Subject: [PATCH] fix int64 bug (#1860) --- python/src/convert.cpp | 7 ++++++- python/src/utils.cpp | 10 +++++++--- python/tests/test_array.py | 12 ++++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 82053100f..b88a5832a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -460,7 +460,12 @@ mx::array create_array(ArrayInitType v, std::optional t) { if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), t.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { - return mx::array(nb::cast(*pv), t.value_or(mx::int32)); + auto val = nb::cast(*pv); + auto default_type = (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + ? mx::int64 + : mx::int32; + return mx::array(val, t.value_or(default_type)); } else if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), t.value_or(mx::float32)); } else if (auto pv = std::get_if>(&v); pv) { diff --git a/python/src/utils.cpp b/python/src/utils.cpp index 70dbb3ddc..e6ca346dc 100644 --- a/python/src/utils.cpp +++ b/python/src/utils.cpp @@ -10,10 +10,14 @@ mx::array to_array( if (auto pv = std::get_if(&v); pv) { return mx::array(nb::cast(*pv), dtype.value_or(mx::bool_)); } else if (auto pv = std::get_if(&v); pv) { - auto out_t = dtype.value_or(mx::int32); + auto val = nb::cast(*pv); + auto default_type = (val > std::numeric_limits::max() || + val < std::numeric_limits::min()) + ? mx::int64 + : mx::int32; + auto out_t = dtype.value_or(default_type); // bool_ is an exception and is always promoted - return mx::array( - nb::cast(*pv), (out_t == mx::bool_) ? mx::int32 : out_t); + return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t); } else if (auto pv = std::get_if(&v); pv) { auto out_t = dtype.value_or(mx::float32); return mx::array( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 86c061289..a5515b87f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -352,6 +352,18 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array(0.0) self.assertFalse(x) + def test_int_type(self): + x = mx.array(1) + self.assertTrue(x.dtype == mx.int32) + x = mx.array(2**32 - 1) + self.assertTrue(x.dtype == mx.int64) + x = mx.array(2**40) + self.assertTrue(x.dtype == mx.int64) + x = mx.array(2**32 - 1, dtype=mx.uint32) + self.assertTrue(x.dtype == mx.uint32) + x = mx.array([1, 2], dtype=mx.int64) + 0x80000000 + self.assertTrue(x.dtype == mx.int64) + def test_construction_from_lists(self): x = mx.array([]) self.assertEqual(x.size, 0)