mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add mx.array.__format__ (#1521)
* add __format__ * actually test something * fix
This commit is contained in:
parent
c9b41d460f
commit
3d17077187
@ -845,6 +845,16 @@ void init_array(nb::module_& m) {
|
||||
nb::rv_policy::none)
|
||||
.def("__int__", [](array& a) { return nb::int_(to_scalar(a)); })
|
||||
.def("__float__", [](array& a) { return nb::float_(to_scalar(a)); })
|
||||
.def(
|
||||
"__format__",
|
||||
[](array& a, nb::object format_spec) {
|
||||
if (a.ndim() > 0) {
|
||||
throw nb::type_error(
|
||||
"unsupported format string passed to mx.array.__format__");
|
||||
}
|
||||
auto obj = to_scalar(a);
|
||||
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
|
||||
})
|
||||
.def(
|
||||
"flatten",
|
||||
[](const array& a,
|
||||
|
@ -1875,6 +1875,16 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
int(a)
|
||||
|
||||
def test_format(self):
|
||||
a = mx.arange(3)
|
||||
self.assertEqual(f"{a[0]:.2f}", "0.00")
|
||||
|
||||
b = mx.array(0.35487)
|
||||
self.assertEqual(f"{b:.1f}", "0.4")
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
s = f"{a:.2f}"
|
||||
|
||||
def test_deep_graphs(self):
|
||||
# The following tests should simply run cleanly without a segfault or
|
||||
# crash due to exceeding recursion depth limits.
|
||||
|
Loading…
Reference in New Issue
Block a user