fix format (#1539)

This commit is contained in:
Awni Hannun 2024-10-28 18:29:14 -07:00 committed by GitHub
parent 015c247393
commit d2ff04a4f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 3 deletions

View File

@ -848,12 +848,19 @@ void init_array(nb::module_& m) {
.def(
"__format__",
[](array& a, nb::object format_spec) {
if (a.ndim() > 0) {
if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {
throw nb::type_error(
"unsupported format string passed to mx.array.__format__");
}
} else if (a.ndim() == 0) {
auto obj = to_scalar(a);
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
return nb::cast<std::string>(
nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));
} else {
nb::gil_scoped_release nogil;
std::ostringstream os;
os << a;
return os.str();
}
})
.def(
"flatten",

View File

@ -1885,6 +1885,9 @@ class TestArray(mlx_tests.MLXTestCase):
with self.assertRaises(TypeError):
s = f"{a:.2f}"
a = mx.array([1, 2, 3])
self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)")
def test_deep_graphs(self):
# The following tests should simply run cleanly without a segfault or
# crash due to exceeding recursion depth limits.