Fix init from double (#2861)
Some checks failed
Build and Test / Check Lint (push) Has been cancelled
Build and Test / Linux (cpu, aarch64) (push) Has been cancelled
Build and Test / Linux (cpu, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, aarch64) (push) Has been cancelled
Build and Test / Linux (cuda-12.6, x86_64) (push) Has been cancelled
Build and Test / Linux (cuda-12.9, x86_64) (push) Has been cancelled
Build and Test / macOS (14.0) (push) Has been cancelled
Build and Test / macOS (15.0) (push) Has been cancelled
Build and Test / Build Documentation (push) Has been cancelled
Build and Test / Linux Fedora (aarch64) (push) Has been cancelled
Build and Test / Linux Fedora (x86_64) (push) Has been cancelled

This commit is contained in:
Awni Hannun
2025-12-03 06:08:11 -08:00
committed by GitHub
parent 193cdcd81a
commit cacbdbf995
2 changed files with 24 additions and 5 deletions

View File

@@ -406,10 +406,16 @@ mx::array array_from_list_impl(
}
}
case pyfloat: {
std::vector<float> vals;
fill_vector(pl, vals);
return mx::array(
vals.begin(), shape, specified_type.value_or(mx::float32));
auto out_type = specified_type.value_or(mx::float32);
if (out_type == mx::float64) {
std::vector<double> vals;
fill_vector(pl, vals);
return mx::array(vals.begin(), shape, out_type);
} else {
std::vector<float> vals;
fill_vector(pl, vals);
return mx::array(vals.begin(), shape, out_type);
}
}
case pycomplex: {
std::vector<std::complex<float>> vals;
@@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
: mx::int32;
return mx::array(val, t.value_or(default_type));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return mx::array(nb::cast<float>(*pv), t.value_or(mx::float32));
auto out_type = t.value_or(mx::float32);
if (out_type == mx::float64) {
return mx::array(nb::cast<double>(*pv), out_type);
} else {
return mx::array(nb::cast<float>(*pv), out_type);
}
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return mx::array(
static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));

View File

@@ -434,6 +434,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0, 4294967295], dtype=mx.float32)
self.assertTrue(np.array_equal(x, xnp))
def test_double_keeps_precision(self):
x = 39.14223403241
out = mx.array(x, dtype=mx.float64).item()
self.assertEqual(out, x)
out = mx.array([x], dtype=mx.float64).item()
self.assertEqual(out, x)
def test_construction_from_lists_of_mlx_arrays(self):
dtypes = [
mx.bool_,