diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index 98c48cca9..44e2a432b 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -14,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) { return print_float_constant(os, x); case bfloat16: return print_float_constant(os, x); + case float64: + return print_float_constant(os, x); case complex64: return print_complex_constant(os, x); case int8: @@ -50,6 +52,8 @@ std::string get_type_string(Dtype d) { return "float16_t"; case bfloat16: return "bfloat16_t"; + case float64: + return "double"; case complex64: return "complex64_t"; case bool_: diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index 6fccaacd6..e92a6d0ad 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -18,8 +18,12 @@ std::string get_type_string(Dtype d); template void print_float_constant(std::ostream& os, const array& x) { auto old_precision = os.precision(); - os << std::setprecision(std::numeric_limits::digits10 + 1) - << x.item() << std::setprecision(old_precision); + if constexpr (std::is_same_v) { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } else { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } + os << x.item() << std::setprecision(old_precision); } template diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 00f8395fc..1340b663a 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -205,6 +205,8 @@ nb::object to_scalar(mx::array& a) { return nb::cast(static_cast(a.item())); case mx::complex64: return nb::cast(a.item>()); + case mx::float64: + return nb::cast(a.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 656553f9d..ca33c2d3a 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -2,6 +2,7 @@ import gc import io +import math import unittest from functools import partial @@ -979,6 +980,17 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(mem_pre, mem_post) + def test_double_constant(self): + with mx.stream(mx.cpu): + x = mx.array(1.0, dtype=mx.float64) + + def fun(x): + return (x + math.pi) * 2.0 + + y = fun(x).item() + y_compiled = mx.compile(fun)(x).item() + self.assertEqual(y, y_compiled) + if __name__ == "__main__": mlx_tests.MLXTestRunner()