Add mx.array.__format__ (#1521)

* add __format__

* actually test something

* fix
This commit is contained in:
Alex Barron 2024-10-24 11:11:39 -07:00 committed by GitHub
parent c9b41d460f
commit 3d17077187
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 0 deletions

View File

@ -845,6 +845,16 @@ void init_array(nb::module_& m) {
nb::rv_policy::none)
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
.def(
"__format__",
[](array& a, nb::object format_spec) {
if (a.ndim() > 0) {
throw nb::type_error(
"unsupported format string passed to mx.array.__format__");
}
auto obj = to_scalar(a);
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
})
.def(
"flatten",
[](const array& a,

View File

@ -1875,6 +1875,16 @@ class TestArray(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
int(a)
def test_format(self):
a = mx.arange(3)
self.assertEqual(f"{a[0]:.2f}", "0.00")
b = mx.array(0.35487)
self.assertEqual(f"{b:.1f}", "0.4")
with self.assertRaises(TypeError):
s = f"{a:.2f}"
def test_deep_graphs(self):
# The following tests should simply run cleanly without a segfault or
# crash due to exceeding recursion depth limits.