diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 549d26512..32af8a078 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -622,6 +622,13 @@ std::vector split( std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) { + auto ax = axis < 0 ? axis + a.ndim() : axis; + if (ax < 0 || ax >= a.ndim()) { + std::ostringstream msg; + msg << "Invalid axis " << axis << " passed to split" + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } auto q_and_r = std::ldiv(a.shape(axis), num_splits); if (q_and_r.rem) { std::ostringstream msg; diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5588ebd62..66e683303 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1027,6 +1027,9 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [[3, 4]]) self.assertEqual(z.tolist(), [[5, 6]]) + with self.assertRaises(ValueError): + mx.split(a, 3, axis=2) + a = mx.arange(8) x, y, z = mx.split(a, [1, 5]) self.assertEqual(x.tolist(), [0])