mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
fix format (#1539)
This commit is contained in:
parent
015c247393
commit
d2ff04a4f2
@ -848,12 +848,19 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"__format__",
|
||||
[](array& a, nb::object format_spec) {
|
||||
if (a.ndim() > 0) {
|
||||
if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {
|
||||
throw nb::type_error(
|
||||
"unsupported format string passed to mx.array.__format__");
|
||||
}
|
||||
} else if (a.ndim() == 0) {
|
||||
auto obj = to_scalar(a);
|
||||
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
|
||||
return nb::cast<std::string>(
|
||||
nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));
|
||||
} else {
|
||||
nb::gil_scoped_release nogil;
|
||||
std::ostringstream os;
|
||||
os << a;
|
||||
return os.str();
|
||||
}
|
||||
})
|
||||
.def(
|
||||
"flatten",
|
||||
|
@ -1885,6 +1885,9 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
s = f"{a:.2f}"
|
||||
|
||||
a = mx.array([1, 2, 3])
|
||||
self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)")
|
||||
|
||||
def test_deep_graphs(self):
|
||||
# The following tests should simply run cleanly without a segfault or
|
||||
# crash due to exceeding recursion depth limits.
|
||||
|
Loading…
Reference in New Issue
Block a user