mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
fix int64 bug (#1860)
This commit is contained in:
parent
0145911bea
commit
55c5ac7820
@ -460,7 +460,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
|
||||
auto val = nb::cast<long>(*pv);
|
||||
auto default_type = (val > std::numeric_limits<int>::max() ||
|
||||
val < std::numeric_limits<int>::min())
|
||||
? mx::int64
|
||||
: mx::int32;
|
||||
return mx::array(val, t.value_or(default_type));
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
|
@ -10,10 +10,14 @@ mx::array to_array(
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(mx::int32);
|
||||
auto val = nb::cast<long>(*pv);
|
||||
auto default_type = (val > std::numeric_limits<int>::max() ||
|
||||
val < std::numeric_limits<int>::min())
|
||||
? mx::int64
|
||||
: mx::int32;
|
||||
auto out_t = dtype.value_or(default_type);
|
||||
// bool_ is an exception and is always promoted
|
||||
return mx::array(
|
||||
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(mx::float32);
|
||||
return mx::array(
|
||||
|
@ -352,6 +352,18 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(0.0)
|
||||
self.assertFalse(x)
|
||||
|
||||
def test_int_type(self):
|
||||
x = mx.array(1)
|
||||
self.assertTrue(x.dtype == mx.int32)
|
||||
x = mx.array(2**32 - 1)
|
||||
self.assertTrue(x.dtype == mx.int64)
|
||||
x = mx.array(2**40)
|
||||
self.assertTrue(x.dtype == mx.int64)
|
||||
x = mx.array(2**32 - 1, dtype=mx.uint32)
|
||||
self.assertTrue(x.dtype == mx.uint32)
|
||||
x = mx.array([1, 2], dtype=mx.int64) + 0x80000000
|
||||
self.assertTrue(x.dtype == mx.int64)
|
||||
|
||||
def test_construction_from_lists(self):
|
||||
x = mx.array([])
|
||||
self.assertEqual(x.size, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user