Improved mx.split() docs (#2689)

* Improved mx.split() documentation

* Fix typo in docstring for array split function

* add example

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Manuel Villanueva
2025-10-24 11:48:41 -05:00
committed by GitHub
parent 5bcf3a6794
commit 233384161e

View File

@@ -2577,13 +2577,23 @@ void init_ops(nb::module_& m) {
a (array): Input array. a (array): Input array.
indices_or_sections (int or list(int)): If ``indices_or_sections`` indices_or_sections (int or list(int)): If ``indices_or_sections``
is an integer the array is split into that many sections of equal is an integer the array is split into that many sections of equal
size. An error is raised if this is not possible. If ``indices_or_sections`` size. An error is raised if this is not possible. If
is a list, the list contains the indices of the start of each subarray ``indices_or_sections`` is a list, then the indices are the split
along the given axis. points, and the array is divided into
``len(indices_or_sections) + 1`` sub-arrays.
axis (int, optional): Axis to split along, defaults to `0`. axis (int, optional): Axis to split along, defaults to `0`.
Returns: Returns:
list(array): A list of split arrays. list(array): A list of split arrays.
Example:
>>> a = mx.array([1, 2, 3, 4], dtype=mx.int32)
>>> mx.split(a, 2)
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
>>> mx.split(a, [1, 3])
[array([1], dtype=int32), array([2, 3], dtype=int32), array([4], dtype=int32)]
)pbdoc"); )pbdoc");
m.def( m.def(
"argmin", "argmin",