From d2ff04a4f295fd68260591b1f932f148408722d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 28 Oct 2024 18:29:14 -0700 Subject: [PATCH] fix format (#1539) --- python/src/array.cpp | 13 ++++++++++--- python/tests/test_array.py | 3 +++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 699193305..c5a3c0cdd 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -848,12 +848,19 @@ void init_array(nb::module_& m) { .def( "__format__", [](array& a, nb::object format_spec) { - if (a.ndim() > 0) { + if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) { throw nb::type_error( "unsupported format string passed to mx.array.__format__"); + } else if (a.ndim() == 0) { + auto obj = to_scalar(a); + return nb::cast( + nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr()))); + } else { + nb::gil_scoped_release nogil; + std::ostringstream os; + os << a; + return os.str(); } - auto obj = to_scalar(a); - return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr())); }) .def( "flatten", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 51968459d..da14675a0 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1885,6 +1885,9 @@ class TestArray(mlx_tests.MLXTestCase): with self.assertRaises(TypeError): s = f"{a:.2f}" + a = mx.array([1, 2, 3]) + self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)") + def test_deep_graphs(self): # The following tests should simply run cleanly without a segfault or # crash due to exceeding recursion depth limits.