mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add Masked Scatter (#2663)
Co-authored-by: Awni Hannun <awni@apple.com> Co-authored-by: Angelos Katharopoulos <katharas@gmail.com> Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -70,7 +70,8 @@ Differences from NumPy
|
||||
|
||||
* Indexing does not perform bounds checking. Indexing out of bounds is
|
||||
undefined behavior.
|
||||
* Boolean mask based indexing is not yet supported.
|
||||
* Boolean mask based indexing is supported for assignment only (see
|
||||
:ref:`boolean-mask-assignment`).
|
||||
|
||||
The reason for the lack of bounds checking is that exceptions cannot propagate
|
||||
from the GPU. Performing bounds checking for array indices before launching the
|
||||
@@ -143,3 +144,51 @@ expected. For example:
|
||||
|
||||
In the above ``dfdx`` will have the correct gradient, namely zeros at ``idx``
|
||||
and ones elsewhere.
|
||||
|
||||
.. _boolean-mask-assignment:
|
||||
|
||||
Boolean Mask Assignment
|
||||
-----------------------
|
||||
|
||||
MLX supports boolean indices using NumPy syntax. A mask must already be
|
||||
a :class:`bool_` MLX :class:`array` or a NumPy ``ndarray`` with ``dtype=bool``.
|
||||
Other index types are routed through the standard scatter code.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1.0, 2.0, 3.0])
|
||||
>>> mask = mx.array([True, False, True])
|
||||
>>> updates = mx.array([5.0, 6.0])
|
||||
>>> a[mask] = updates
|
||||
>>> a
|
||||
array([5.0, 2.0, 6.0], dtype=float32)
|
||||
|
||||
Scalar assignments broadcast to every ``True`` entry in ``mask``. For non-scalar
|
||||
assignments, ``updates`` must provide at least as many elements as there are
|
||||
``True`` entries in ``mask``.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.zeros((2, 3))
|
||||
>>> mask = mx.array([[True, False, True],
|
||||
[False, False, True]])
|
||||
>>> a[mask] = 1.0
|
||||
>>> a
|
||||
array([[1.0, 0.0, 1.0],
|
||||
[0.0, 0.0, 1.0]], dtype=float32)
|
||||
|
||||
Boolean masks follow NumPy semantics:
|
||||
|
||||
- The mask shape must match the shape of the axes it indexes exactly. No mask
|
||||
broadcasting occurs.
|
||||
- Any axes not covered by the mask are taken in full.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
Shapes such as ``(1, 10, 10)`` or ``(10, 10, 1)`` do not match the indexed
|
||||
axes and therefore raise errors.
|
||||
|
||||
Reference in New Issue
Block a user