fix format (#1539)

This commit is contained in:
Awni Hannun
2024-10-28 18:29:14 -07:00
committed by GitHub
parent 015c247393
commit d2ff04a4f2
2 changed files with 13 additions and 3 deletions

View File

@@ -848,12 +848,19 @@ void init_array(nb::module_& m) {
.def(
"__format__",
[](array& a, nb::object format_spec) {
if (a.ndim() > 0) {
if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {
throw nb::type_error(
"unsupported format string passed to mx.array.__format__");
} else if (a.ndim() == 0) {
auto obj = to_scalar(a);
return nb::cast<std::string>(
nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));
} else {
nb::gil_scoped_release nogil;
std::ostringstream os;
os << a;
return os.str();
}
auto obj = to_scalar(a);
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
})
.def(
"flatten",