Add mx.array.__format__ (#1521)

* add __format__

* actually test something

* fix
This commit is contained in:
Alex Barron
2024-10-24 11:11:39 -07:00
committed by GitHub
parent c9b41d460f
commit 3d17077187
2 changed files with 20 additions and 0 deletions

View File

@@ -845,6 +845,16 @@ void init_array(nb::module_& m) {
nb::rv_policy::none)
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
.def(
"__format__",
[](array& a, nb::object format_spec) {
if (a.ndim() > 0) {
throw nb::type_error(
"unsupported format string passed to mx.array.__format__");
}
auto obj = to_scalar(a);
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
})
.def(
"flatten",
[](const array& a,