From 85143fecdd87e7b2e390aa0c39123195ef485cf6 Mon Sep 17 00:00:00 2001 From: toji Date: Thu, 15 Feb 2024 20:55:38 +0530 Subject: [PATCH] improved error msg for invalid axis(`mx.split`) (#685) * improved error msg for invalid axis(`mx.split`) * Apply suggestions from code review Co-authored-by: Awni Hannun * fixed formatting issue --------- Co-authored-by: Awni Hannun --- mlx/ops.cpp | 7 +++++++ python/tests/test_ops.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 549d26512..32af8a078 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -622,6 +622,13 @@ std::vector split( std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { + auto ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "Invalid axis " << axis << " passed to split" + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto q_and_r = std::ldiv(a.shape(axis), num_splits); if (q_and_r.rem) { std::ostringstream msg; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5588ebd62..66e683303 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1027,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [[3, 4]]) self.assertEqual(z.tolist(), [[5, 6]]) + with self.assertRaises(ValueError): + mx.split(a, 3, axis=2) + a = mx.arange(8) x, y, z = mx.split(a, [1, 5]) self.assertEqual(x.tolist(), [0])