mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 20:41:13 +08:00
fix format (#1539)
This commit is contained in:
parent
015c247393
commit
d2ff04a4f2
@ -848,12 +848,19 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__format__",
|
"__format__",
|
||||||
[](array& a, nb::object format_spec) {
|
[](array& a, nb::object format_spec) {
|
||||||
if (a.ndim() > 0) {
|
if (nb::len(nb::str(format_spec)) > 0 && a.ndim() > 0) {
|
||||||
throw nb::type_error(
|
throw nb::type_error(
|
||||||
"unsupported format string passed to mx.array.__format__");
|
"unsupported format string passed to mx.array.__format__");
|
||||||
|
} else if (a.ndim() == 0) {
|
||||||
|
auto obj = to_scalar(a);
|
||||||
|
return nb::cast<std::string>(
|
||||||
|
nb::handle(PyObject_Format(obj.ptr(), format_spec.ptr())));
|
||||||
|
} else {
|
||||||
|
nb::gil_scoped_release nogil;
|
||||||
|
std::ostringstream os;
|
||||||
|
os << a;
|
||||||
|
return os.str();
|
||||||
}
|
}
|
||||||
auto obj = to_scalar(a);
|
|
||||||
return nb::str(PyObject_Format(obj.ptr(), format_spec.ptr()));
|
|
||||||
})
|
})
|
||||||
.def(
|
.def(
|
||||||
"flatten",
|
"flatten",
|
||||||
|
@ -1885,6 +1885,9 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
s = f"{a:.2f}"
|
s = f"{a:.2f}"
|
||||||
|
|
||||||
|
a = mx.array([1, 2, 3])
|
||||||
|
self.assertEqual(f"{a}", "array([1, 2, 3], dtype=int32)")
|
||||||
|
|
||||||
def test_deep_graphs(self):
|
def test_deep_graphs(self):
|
||||||
# The following tests should simply run cleanly without a segfault or
|
# The following tests should simply run cleanly without a segfault or
|
||||||
# crash due to exceeding recursion depth limits.
|
# crash due to exceeding recursion depth limits.
|
||||||
|
Loading…
Reference in New Issue
Block a user