mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
This commit is contained in:
@@ -406,10 +406,16 @@ mx::array array_from_list_impl(
|
||||
}
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(
|
||||
vals.begin(), shape, specified_type.value_or(mx::float32));
|
||||
auto out_type = specified_type.value_or(mx::float32);
|
||||
if (out_type == mx::float64) {
|
||||
std::vector<double> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(vals.begin(), shape, out_type);
|
||||
} else {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(vals.begin(), shape, out_type);
|
||||
}
|
||||
}
|
||||
case pycomplex: {
|
||||
std::vector<std::complex<float>> vals;
|
||||
@@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
: mx::int32;
|
||||
return mx::array(val, t.value_or(default_type));
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||
auto out_type = t.value_or(mx::float32);
|
||||
if (out_type == mx::float64) {
|
||||
return mx::array(nb::cast<double>(*pv), out_type);
|
||||
} else {
|
||||
return mx::array(nb::cast<float>(*pv), out_type);
|
||||
}
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return mx::array(
|
||||
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
|
||||
|
||||
Reference in New Issue
Block a user