improved error msg for invalid axis(mx.split) (#685)

* improved error msg for invalid axis(`mx.split`)

* Apply suggestions from code review

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fixed formatting issue

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
toji 2024-02-15 20:55:38 +05:30 committed by GitHub
parent 35431a4ac8
commit 85143fecdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 0 deletions

View File

@ -622,6 +622,13 @@ std::vector<array> split(
std::vector<array> std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis;
if (ax < 0 || ax >= a.ndim()) {
std::ostringstream msg;
msg << "Invalid axis " << axis << " passed to split"
<< " for array with shape " << a.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto q_and_r = std::ldiv(a.shape(axis), num_splits); auto q_and_r = std::ldiv(a.shape(axis), num_splits);
if (q_and_r.rem) { if (q_and_r.rem) {
std::ostringstream msg; std::ostringstream msg;

View File

@ -1027,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [[3, 4]]) self.assertEqual(y.tolist(), [[3, 4]])
self.assertEqual(z.tolist(), [[5, 6]]) self.assertEqual(z.tolist(), [[5, 6]])
with self.assertRaises(ValueError):
mx.split(a, 3, axis=2)
a = mx.arange(8) a = mx.arange(8)
x, y, z = mx.split(a, [1, 5]) x, y, z = mx.split(a, [1, 5])
self.assertEqual(x.tolist(), [0]) self.assertEqual(x.tolist(), [0])