mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 19:48:15 +08:00
Add mx.array.__format__ (#1521)
* add __format__ * actually test something * fix
This commit is contained in:
@@ -845,6 +845,16 @@ void init_array(nb::module_& m) {
|
||||
nb::rv_policy::none)
|
||||
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
|
||||
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
|
||||
.def(
|
||||
"__format__",
|
||||
[](array& a, nb::object format_spec) {
|
||||
if (a.ndim() > 0) {
|
||||
throw nb::type_error(
|
||||
"unsupported format string passed to mx.array.__format__");
|
||||
}
|
||||
auto obj = to_scalar(a);
|
||||
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
|
||||
})
|
||||
.def(
|
||||
"flatten",
|
||||
[](const array& a,
|
||||
|
||||
Reference in New Issue
Block a user