Add mx.array.__format__ (#1521)

* add __format__

* actually test something

* fix
This commit is contained in:
Alex Barron
2024-10-24 11:11:39 -07:00
committed by GitHub
parent c9b41d460f
commit 3d17077187
2 changed files with 20 additions and 0 deletions

View File

@@ -1875,6 +1875,16 @@ class TestArray(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
int(a)
def test_format(self):
a = mx.arange(3)
self.assertEqual(f"{a[0]:.2f}", "0.00")
b = mx.array(0.35487)
self.assertEqual(f"{b:.1f}", "0.4")
with self.assertRaises(TypeError):
s = f"{a:.2f}"
def test_deep_graphs(self):
# The following tests should simply run cleanly without a segfault or
# crash due to exceeding recursion depth limits.