fix int64 bug (#1860)

This commit is contained in:
Alex Barron 2025-02-12 19:23:46 -08:00 committed by GitHub
parent 0145911bea
commit 55c5ac7820
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 4 deletions

View File

@ -460,7 +460,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
auto val = nb::cast<long>(*pv);
auto default_type = (val > std::numeric_limits<int>::max() ||
val < std::numeric_limits<int>::min())
? mx::int64
: mx::int32;
return mx::array(val, t.value_or(default_type));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {

View File

@ -10,10 +10,14 @@ mx::array to_array(
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(mx::int32);
auto val = nb::cast<long>(*pv);
auto default_type = (val > std::numeric_limits<int>::max() ||
val < std::numeric_limits<int>::min())
? mx::int64
: mx::int32;
auto out_t = dtype.value_or(default_type);
// bool_ is an exception and is always promoted
return mx::array(
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(mx::float32);
return mx::array(

View File

@ -352,6 +352,18 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(0.0)
self.assertFalse(x)
def test_int_type(self):
x = mx.array(1)
self.assertTrue(x.dtype == mx.int32)
x = mx.array(2**32 - 1)
self.assertTrue(x.dtype == mx.int64)
x = mx.array(2**40)
self.assertTrue(x.dtype == mx.int64)
x = mx.array(2**32 - 1, dtype=mx.uint32)
self.assertTrue(x.dtype == mx.uint32)
x = mx.array([1, 2], dtype=mx.int64) + 0x80000000
self.assertTrue(x.dtype == mx.int64)
def test_construction_from_lists(self):
x = mx.array([])
self.assertEqual(x.size, 0)