mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-08 04:08:54 +08:00
Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled
This commit is contained in:
@@ -406,10 +406,16 @@ mx::array array_from_list_impl(
|
||||
}
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(
|
||||
vals.begin(), shape, specified_type.value_or(mx::float32));
|
||||
auto out_type = specified_type.value_or(mx::float32);
|
||||
if (out_type == mx::float64) {
|
||||
std::vector<double> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(vals.begin(), shape, out_type);
|
||||
} else {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return mx::array(vals.begin(), shape, out_type);
|
||||
}
|
||||
}
|
||||
case pycomplex: {
|
||||
std::vector<std::complex<float>> vals;
|
||||
@@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
|
||||
: mx::int32;
|
||||
return mx::array(val, t.value_or(default_type));
|
||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
|
||||
auto out_type = t.value_or(mx::float32);
|
||||
if (out_type == mx::float64) {
|
||||
return mx::array(nb::cast<double>(*pv), out_type);
|
||||
} else {
|
||||
return mx::array(nb::cast<float>(*pv), out_type);
|
||||
}
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return mx::array(
|
||||
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
|
||||
|
||||
@@ -434,6 +434,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([0, 4294967295], dtype=mx.float32)
|
||||
self.assertTrue(np.array_equal(x, xnp))
|
||||
|
||||
def test_double_keeps_precision(self):
|
||||
x = 39.14223403241
|
||||
out = mx.array(x, dtype=mx.float64).item()
|
||||
self.assertEqual(out, x)
|
||||
|
||||
out = mx.array([x], dtype=mx.float64).item()
|
||||
self.assertEqual(out, x)
|
||||
|
||||
def test_construction_from_lists_of_mlx_arrays(self):
|
||||
dtypes = [
|
||||
mx.bool_,
|
||||
|
||||
Reference in New Issue
Block a user