.. _indexing:

Indexing Arrays
===============

.. currentmodule:: mlx.core

For the most part, indexing an MLX :obj:`array` works the same as indexing a
NumPy :obj:`numpy.ndarray`. See the `NumPy documentation
<https://numpy.org/doc/stable/user/basics.indexing.html>`_ for more details on
how that works.

For example, you can use regular integers and slices (:obj:`slice`) to index arrays:

.. code-block:: shell

  >>> arr = mx.arange(10)
  >>> arr[3]
  array(3, dtype=int32)
  >>> arr[-2]  # negative indexing works
  array(8, dtype=int32)
  >>> arr[2:8:2]  # start, stop, stride
  array([2, 4, 6], dtype=int32)

For multi-dimensional arrays, the ``...`` or :obj:`Ellipsis` syntax works as in NumPy:

.. code-block:: shell

  >>> arr = mx.arange(8).reshape(2, 2, 2)
  >>> arr[:, :, 0]
  array([[0, 2],
         [4, 6]], dtype=int32)
  >>> arr[..., 0]
  array([[0, 2],
         [4, 6]], dtype=int32)

You can index with ``None`` to create a new axis:

.. code-block:: shell

  >>> arr = mx.arange(8)
  >>> arr.shape
  (8,)
  >>> arr[None].shape
  (1, 8)

You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

  >>> arr = mx.arange(10)
  >>> idx = mx.array([5, 7])
  >>> arr[idx]
  array([5, 7], dtype=int32)

Mixing and matching integers, :obj:`slice`, ``...``, and :obj:`array` indices
works just as in NumPy.
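
As a rough illustration, the sketch below combines several index types in a
single expression. The array, the index values, and the variable names are
made up for this example, and the shapes noted in the comments assume MLX
follows NumPy's rules for adjacent integer and array indices.

.. code-block:: python

  import mlx.core as mx

  a = mx.arange(24).reshape(2, 3, 4)
  idx = mx.array([0, 2])

  # Integer on axis 0, array on axis 1, full slice on axis 2 -> shape (2, 4)
  rows = a[1, idx, :]

  # Ellipsis followed by a strided slice on the last axis -> shape (2, 3, 2)
  evens = a[..., ::2]

  # None inserts a new leading axis -> shape (1, 2, 3, 4)
  expanded = a[None]

  print(rows.shape, evens.shape, expanded.shape)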

Other functions which may be useful for indexing arrays are :func:`take` and
:func:`take_along_axis`.
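
A minimal sketch of both functions on a small made-up array: :func:`take`
gathers whole slices along one axis, while :func:`take_along_axis` expects an
index array with the same number of dimensions as the input and picks
elements position by position along the given axis.

.. code-block:: python

  import mlx.core as mx

  a = mx.arange(12).reshape(3, 4)

  # take: gather whole rows along axis 0 (similar to a[mx.array([2, 0])])
  rows = mx.take(a, mx.array([2, 0]), axis=0)

  # take_along_axis: pick one column per row; cols has the same rank as a
  cols = mx.array([[3], [0], [1]])
  picked = mx.take_along_axis(a, cols, axis=1)

  print(rows)
  print(picked)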

Differences from NumPy
----------------------

.. Note::

  MLX indexing is different from NumPy indexing in two important ways:

  * Indexing does not perform bounds checking. Indexing out of bounds is
    undefined behavior.
  * Boolean mask-based indexing is supported for assignment only (see
    :ref:`boolean-mask-assignment`).

The reason for the lack of bounds checking is that exceptions cannot propagate
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
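
Because bounds checking is left to the caller, one option while debugging is
to validate indices on the host before indexing. The ``checked_take`` helper
below is hypothetical and not part of MLX. It reduces the check to a single
boolean with :func:`all` and reads it back with ``.item()``, which forces
evaluation and adds a synchronization point, so it is best reserved for
debugging. This sketch only accepts non-negative indices.

.. code-block:: python

  import mlx.core as mx


  def checked_take(x, idx, axis=0):
      """Hypothetical helper: gather ``idx`` along ``axis`` after a bounds check.

      ``.item()`` evaluates the check, so this synchronizes with the device.
      Only indices in ``[0, x.shape[axis])`` are accepted.
      """
      n = x.shape[axis]
      in_bounds = mx.all(mx.logical_and(idx >= 0, idx < n))
      if not in_bounds.item():
          raise IndexError(f"index out of bounds for axis {axis} with size {n}")
      return mx.take(x, idx, axis=axis)


  x = mx.arange(5)
  print(checked_take(x, mx.array([1, 3])))  # within bounds
  try:
      checked_take(x, mx.array([1, 7]))  # 7 is out of bounds for size 5
  except IndexError as err:
      print(err)

If clamping is acceptable instead of raising, :func:`clip` can force indices
into range without a host round trip, at the cost of silently changing what an
out-of-range lookup returns.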

Selecting elements with a boolean mask (for example ``x[mask]``) is something
that MLX may support in the future. In general, MLX has limited support for
operations whose output *shapes* depend on input *data*. Other examples of such
operations that MLX does not yet support include :func:`numpy.nonzero` and the
single-input version of :func:`numpy.where`.
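
A common way to sidestep data-dependent shapes is to keep the full shape and
neutralize the unselected entries, for example with the three-argument form
of :func:`where`. The sketch below computes the mean of the positive entries
without ever materializing ``x[x > 0]``; the data is made up for illustration.

.. code-block:: python

  import mlx.core as mx

  x = mx.array([3.0, -1.0, 4.0, -1.5])
  mask = x > 0

  # NumPy's x[mask] would have a data-dependent shape. Instead, keep all four
  # entries and replace the unselected ones with 0 so the shape stays fixed.
  kept = mx.where(mask, x, 0.0)

  # Number of selected entries, computed from the mask cast to float
  count = mask.astype(mx.float32).sum()
  mean_positive = kept.sum() / count

  print(kept)           # zeros where mask is False
  print(mean_positive)  # (3 + 4) / 2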

In-Place Updates
----------------

In-place updates to indexed arrays are possible in MLX. For example:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> a[2] = 0
  >>> a
  array([1, 2, 0], dtype=int32)

Just as in NumPy, in-place updates are reflected in all references to the
same array:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> b = a
  >>> b[2] = 0
  >>> b
  array([1, 2, 0], dtype=int32)
  >>> a
  array([1, 2, 0], dtype=int32)

Note that, unlike NumPy, slicing an array creates a copy, not a view, so
mutating the slice does not mutate the original array:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> b = a[:]
  >>> b[2] = 0
  >>> b
  array([1, 2, 0], dtype=int32)
  >>> a
  array([1, 2, 3], dtype=int32)

Also unlike NumPy, when several updates target the same location, which one
takes effect is nondeterministic:

.. code-block:: shell

  >>> a = mx.array([1, 2, 3])
  >>> a[[0, 0]] = mx.array([4, 5])

The first element of ``a`` could be ``4`` or ``5``.
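
When updates to repeated indices should combine rather than race, the
``array.at`` property applies them with an explicit reduction. The sketch
below uses ``.at[...].add``, which accumulates every contribution to a
repeated index, so the result no longer depends on the order of writes. It
returns a new array rather than modifying ``a`` in place.

.. code-block:: python

  import mlx.core as mx

  a = mx.array([1, 2, 3])
  idx = mx.array([0, 0])
  updates = mx.array([4, 5])

  # Both updates target index 0; .at[...].add sums them instead of letting one
  # write overwrite the other. The result is a new array: a is left unchanged.
  b = a.at[idx].add(updates)

  print(b)  # first element is 1 + 4 + 5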

Transformations of functions which use in-place updates are allowed and work as
expected. For example:

.. code-block:: python

  import mlx.core as mx

  def fun(x, idx):
      x[idx] = 2.0
      return x.sum()

  dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
  print(dfdx)  # Prints: array([1, 0, 1], dtype=float32)

In the above, ``dfdx`` has the correct gradient: zeros at ``idx`` and ones
elsewhere.

.. _boolean-mask-assignment:

Boolean Mask Assignment
-----------------------

MLX supports assignment through boolean masks using NumPy syntax. A mask must
already be a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with
``dtype=bool``; index objects of any other type are routed through the standard
scatter code.

.. code-block:: shell

  >>> a = mx.array([1.0, 2.0, 3.0])
  >>> mask = mx.array([True, False, True])
  >>> updates = mx.array([5.0, 6.0])
  >>> a[mask] = updates
  >>> a
  array([5, 2, 6], dtype=float32)

Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
assignments, ``updates`` must provide at least as many elements as there are
``True`` entries in ``mask``.

.. code-block:: shell

  >>> a = mx.zeros((2, 3))
  >>> mask = mx.array([[True, False, True],
  ...                  [False, False, True]])
  >>> a[mask] = 1.0
  >>> a
  array([[1, 0, 1],
         [0, 0, 1]], dtype=float32)

Boolean masks follow NumPy semantics:

- The mask shape must exactly match the shape of the axes it indexes. No mask
  broadcasting occurs.
- Any axes not covered by the mask are taken in full.

.. code-block:: shell

  >>> a = mx.arange(1000).reshape(10, 10, 10)
  >>> a[mx.random.normal((10, 10)) > 0.0] = 0  # valid: mask covers axes 0 and 1

The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
Masks with shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the
indexed axes and therefore raise errors.
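
One way to see the "axes not covered by the mask are taken in full" rule is to
compare mask assignment with :func:`where`. Unlike a mask index, the condition
passed to :func:`where` does broadcast, so giving the mask a trailing axis of
size 1 spreads it over the last dimension. The sketch below assumes the two
forms agree element for element; the random seed is arbitrary.

.. code-block:: python

  import mlx.core as mx

  mx.random.seed(0)
  mask = mx.random.normal((10, 10)) > 0.0

  # Boolean mask assignment: the (10, 10) mask covers axes 0 and 1
  a = mx.arange(1000).reshape(10, 10, 10)
  a[mask] = 0

  # Same result via where: the (10, 10, 1) condition broadcasts over axis 2,
  # selecting 0 wherever mask[i, j] is True and the original value elsewhere
  b = mx.where(mx.expand_dims(mask, -1), 0, mx.arange(1000).reshape(10, 10, 10))

  print(mx.array_equal(a, b))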