improved error msg for invalid axis(mx.split) (#685)

* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
toji
2024-02-15 20:55:38 +05:30
committed by GitHub
parent 35431a4ac8
commit 85143fecdd
2 changed files with 10 additions and 0 deletions

View File

@@ -1027,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [[3, 4]])
self.assertEqual(z.tolist(), [[5, 6]])
with self.assertRaises(ValueError):
mx.split(a, 3, axis=2)
a = mx.arange(8)
x, y, z = mx.split(a, [1, 5])
self.assertEqual(x.tolist(), [0])