diff --git a/python/src/array.cpp b/python/src/array.cpp index c9bc8eebe..699193305 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -845,6 +845,16 @@ void init_array(nb::module_& m) { nb::rv_policy::none) .def("__int__", [](array& a) { return nb::int_(to_scalar(a)); }) .def("__float__", [](array& a) { return nb::float_(to_scalar(a)); }) + .def( + "__format__", + [](array& a, nb::object format_spec) { + if (a.ndim() > 0) { + throw nb::type_error( + "unsupported format string passed to mx.array.__format__"); + } + auto obj = to_scalar(a); + return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr())); + }) .def( "flatten", [](const array& a, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index d1ec3427d..51968459d 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1875,6 +1875,16 @@ class TestArray(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): int(a) + def test_format(self): + a = mx.arange(3) + self.assertEqual(f"{a[0]:.2f}", "0.00") + + b = mx.array(0.35487) + self.assertEqual(f"{b:.1f}", "0.4") + + with self.assertRaises(TypeError): + s = f"{a:.2f}" + def test_deep_graphs(self): # The following tests should simply run cleanly without a segfault or # crash due to exceeding recursion depth limits.