fix int64 bug (#1860)

This commit is contained in:
Alex Barron
2025-02-12 19:23:46 -08:00
committed by GitHub
parent 0145911bea
commit 55c5ac7820
3 changed files with 25 additions and 4 deletions

View File

@@ -460,7 +460,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
auto val = nb::cast<long>(*pv);
auto default_type = (val > std::numeric_limits<int>::max() ||
val < std::numeric_limits<int>::min())
? mx::int64
: mx::int32;
return mx::array(val, t.value_or(default_type));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {

View File

@@ -10,10 +10,14 @@ mx::array to_array(
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
auto out_t = dtype.value_or(mx::int32);
auto val = nb::cast<long>(*pv);
auto default_type = (val > std::numeric_limits<int>::max() ||
val < std::numeric_limits<int>::min())
? mx::int64
: mx::int32;
auto out_t = dtype.value_or(default_type);
// bool_ is an exception and is always promoted
return mx::array(
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
auto out_t = dtype.value_or(mx::float32);
return mx::array(