From 233384161eb7024b555e3e7ab6c18f358b43f2ec Mon Sep 17 00:00:00 2001 From: Manuel Villanueva <118570103+Maalvi14@users.noreply.github.com> Date: Fri, 24 Oct 2025 11:48:41 -0500 Subject: [PATCH] Improved mx.split() docs (#2689) * Improved mx.split() documentation * Fix typo in docstring for array split function * add example --------- Co-authored-by: Awni Hannun --- python/src/ops.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 27d330d6e..2bf1a7ab1 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2577,13 +2577,23 @@ void init_ops(nb::module_& m) { a (array): Input array. indices_or_sections (int or list(int)): If ``indices_or_sections`` is an integer the array is split into that many sections of equal - size. An error is raised if this is not possible. If ``indices_or_sections`` - is a list, the list contains the indices of the start of each subarray - along the given axis. + size. An error is raised if this is not possible. If + ``indices_or_sections`` is a list, then the indices are the split + points, and the array is divided into + ``len(indices_or_sections) + 1`` sub-arrays. axis (int, optional): Axis to split along, defaults to `0`. Returns: list(array): A list of split arrays. + + Example: + + >>> a = mx.array([1, 2, 3, 4], dtype=mx.int32) + >>> mx.split(a, 2) + [array([1, 2], dtype=int32), array([3, 4], dtype=int32)] + >>> mx.split(a, [1, 3]) + [array([1], dtype=int32), array([2, 3], dtype=int32), array([4], dtype=int32)] + )pbdoc"); m.def( "argmin",