mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-15 21:21:16 +08:00
Add docstring for scatter (#1189)
* Add docstring for scatter * docs nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
0163a8e57a
commit
0e585b4409
99
mlx/ops.h
99
mlx/ops.h
@ -927,7 +927,104 @@ array take_along_axis(
|
|||||||
int axis,
|
int axis,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Scatter updates to given linear indices */
|
/** Scatter updates to the given indices.
|
||||||
|
*
|
||||||
|
* The parameters ``indices`` and ``axes`` determine the locations of ``a``
|
||||||
|
* that are updated with the values in ``updates``. Assuming 1-d ``indices``
|
||||||
|
* for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
|
||||||
|
* the values in ``updates`` will be applied. Note each array in
|
||||||
|
* ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
|
||||||
|
* axes.size()``. If an index/axis pair is not provided then indices along that
|
||||||
|
* axis are assumed to be zero.
|
||||||
|
*
|
||||||
|
* Note the rank of ``updates`` must be equal to the sum of the rank of the
|
||||||
|
* broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
|
||||||
|
* arrays in ``indices`` have the same shape, ``updates.ndim() ==
|
||||||
|
* indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
|
||||||
|
* correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
|
||||||
|
* values that will be applied to the given location in ``a``.
|
||||||
|
*
|
||||||
|
* For example:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* auto in = zeros({4, 4}, float32);
|
||||||
|
* auto indices = array({2});
|
||||||
|
* auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
|
||||||
|
* std::vector<int> axes{0};
|
||||||
|
*
|
||||||
|
* auto out = scatter(in, {indices}, updates, axes);
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* will produce:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* array([[0, 0, 0, 0],
|
||||||
|
* [0, 0, 0, 0],
|
||||||
|
* [1, 2, 0, 0],
|
||||||
|
* [0, 0, 0, 0]], dtype=float32)
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
|
||||||
|
* 0)`` position of ``a``.
|
||||||
|
*
|
||||||
|
* Adding another element to ``indices`` will scatter into another location of
|
||||||
|
* ``a``. We also have to add an another update for the new index:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* auto in = zeros({4, 4}, float32);
|
||||||
|
* auto indices = array({2, 0});
|
||||||
|
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
|
||||||
|
* std::vector<int> axes{0};
|
||||||
|
*
|
||||||
|
* auto out = scatter(in, {indices}, updates, axes):
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* will produce:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* array([[3, 4, 0, 0],
|
||||||
|
* [0, 0, 0, 0],
|
||||||
|
* [1, 2, 0, 0],
|
||||||
|
* [0, 0, 0, 0]], dtype=float32)
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* To control the scatter location on an additional axis, add another index
|
||||||
|
* array to ``indices`` and another axis to ``axes``:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* auto in = zeros({4, 4}, float32);
|
||||||
|
* auto indices = std::vector{array({2, 0}), array({1, 2})};
|
||||||
|
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
|
||||||
|
* std::vector<int> axes{0, 1};
|
||||||
|
*
|
||||||
|
* auto out = scatter(in, indices, updates, axes);
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* will produce:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* array([[0, 0, 3, 4],
|
||||||
|
* [0, 0, 0, 0],
|
||||||
|
* [0, 1, 2, 0],
|
||||||
|
* [0, 0, 0, 0]], dtype=float32)
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* Items in indices are broadcasted together. This means:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* auto indices = std::vector{array({2, 0}), array({1})};
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* is equivalent to:
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* auto indices = std::vector{array({2, 0}), array({1, 1})};
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* Note, ``scatter`` does not perform bounds checking on the indices and
|
||||||
|
* updates. Out-of-bounds accesses on ``a`` are undefined and typically result
|
||||||
|
* in unintended or invalid memory writes.
|
||||||
|
*/
|
||||||
array scatter(
|
array scatter(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<array>& indices,
|
const std::vector<array>& indices,
|
||||||
|
Loading…
Reference in New Issue
Block a user