mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
fix int64 bug (#1860)
This commit is contained in:
@@ -460,7 +460,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
return mx::array(nb::cast<int>(*pv), t.value_or(mx::int32));
|
||||
auto val = nb::cast<long>(*pv);
|
||||
auto default_type = (val > std::numeric_limits<int>::max() ||
|
||||
val < std::numeric_limits<int>::min())
|
||||
? mx::int64
|
||||
: mx::int32;
|
||||
return mx::array(val, t.value_or(default_type));
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
|
@@ -10,10 +10,14 @@ mx::array to_array(
|
||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||
return mx::array(nb::cast<bool>(*pv), dtype.value_or(mx::bool_));
|
||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(mx::int32);
|
||||
auto val = nb::cast<long>(*pv);
|
||||
auto default_type = (val > std::numeric_limits<int>::max() ||
|
||||
val < std::numeric_limits<int>::min())
|
||||
? mx::int64
|
||||
: mx::int32;
|
||||
auto out_t = dtype.value_or(default_type);
|
||||
// bool_ is an exception and is always promoted
|
||||
return mx::array(
|
||||
nb::cast<int>(*pv), (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
return mx::array(val, (out_t == mx::bool_) ? mx::int32 : out_t);
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
auto out_t = dtype.value_or(mx::float32);
|
||||
return mx::array(
|
||||
|
Reference in New Issue
Block a user