diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 88da37103..0367aceb0 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -406,10 +406,16 @@ mx::array array_from_list_impl( } } case pyfloat: { - std::vector vals; - fill_vector(pl, vals); - return mx::array( - vals.begin(), shape, specified_type.value_or(mx::float32)); + auto out_type = specified_type.value_or(mx::float32); + if (out_type == mx::float64) { + std::vector vals; + fill_vector(pl, vals); + return mx::array(vals.begin(), shape, out_type); + } else { + std::vector vals; + fill_vector(pl, vals); + return mx::array(vals.begin(), shape, out_type); + } } case pycomplex: { std::vector> vals; @@ -470,7 +476,12 @@ mx::array create_array(ArrayInitType v, std::optional t) { : mx::int32; return mx::array(val, t.value_or(default_type)); } else if (auto pv = std::get_if(&v); pv) { - return mx::array(nb::cast(*pv), t.value_or(mx::float32)); + auto out_type = t.value_or(mx::float32); + if (out_type == mx::float64) { + return mx::array(nb::cast(*pv), out_type); + } else { + return mx::array(nb::cast(*pv), out_type); + } } else if (auto pv = std::get_if>(&v); pv) { return mx::array( static_cast(*pv), t.value_or(mx::complex64)); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index e2aa74f04..601e61674 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -434,6 +434,14 @@ class TestArray(mlx_tests.MLXTestCase): x = mx.array([0, 4294967295], dtype=mx.float32) self.assertTrue(np.array_equal(x, xnp)) + def test_double_keeps_precision(self): + x = 39.14223403241 + out = mx.array(x, dtype=mx.float64).item() + self.assertEqual(out, x) + + out = mx.array([x], dtype=mx.float64).item() + self.assertEqual(out, x) + def test_construction_from_lists_of_mlx_arrays(self): dtypes = [ mx.bool_,