Add docstring for scatter (#1189)

* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov 2024-06-06 20:51:25 +02:00 committed by GitHub
parent 0163a8e57a
commit 0e585b4409
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -927,7 +927,104 @@ array take_along_axis(
int axis,
StreamOrDevice s = {});
/** Scatter updates to given linear indices */
/** Scatter updates to the given indices.
*
* The parameters ``indices`` and ``axes`` determine the locations of ``a``
* that are updated with the values in ``updates``. Assuming 1-d ``indices``
* for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
* the values in ``updates`` will be applied. Note each array in
* ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
* axes.size()``. If an index/axis pair is not provided then indices along that
* axis are assumed to be zero.
*
* Note the rank of ``updates`` must be equal to the sum of the rank of the
* broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
* arrays in ``indices`` have the same shape, ``updates.ndim() ==
* indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
* correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
* values that will be applied to the given location in ``a``.
*
* For example:
*
* @code
* auto in = zeros({4, 4}, float32);
* auto indices = array({2});
* auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
* std::vector<int> axes{0};
*
* auto out = scatter(in, {indices}, updates, axes);
* @endcode
*
* will produce:
*
* @code
* array([[0, 0, 0, 0],
* [0, 0, 0, 0],
* [1, 2, 0, 0],
* [0, 0, 0, 0]], dtype=float32)
* @endcode
*
* This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
* 0)`` position of ``a``.
*
* Adding another element to ``indices`` will scatter into another location of
* ``a``. We also have to add an another update for the new index:
*
* @code
* auto in = zeros({4, 4}, float32);
* auto indices = array({2, 0});
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
* std::vector<int> axes{0};
*
* auto out = scatter(in, {indices}, updates, axes):
* @endcode
*
* will produce:
*
* @code
* array([[3, 4, 0, 0],
* [0, 0, 0, 0],
* [1, 2, 0, 0],
* [0, 0, 0, 0]], dtype=float32)
* @endcode
*
* To control the scatter location on an additional axis, add another index
* array to ``indices`` and another axis to ``axes``:
*
* @code
* auto in = zeros({4, 4}, float32);
* auto indices = std::vector{array({2, 0}), array({1, 2})};
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
* std::vector<int> axes{0, 1};
*
* auto out = scatter(in, indices, updates, axes);
* @endcode
*
* will produce:
*
* @code
* array([[0, 0, 3, 4],
* [0, 0, 0, 0],
* [0, 1, 2, 0],
* [0, 0, 0, 0]], dtype=float32)
* @endcode
*
* Items in indices are broadcasted together. This means:
*
* @code
* auto indices = std::vector{array({2, 0}), array({1})};
* @endcode
*
* is equivalent to:
*
* @code
* auto indices = std::vector{array({2, 0}), array({1, 1})};
* @endcode
*
* Note, ``scatter`` does not perform bounds checking on the indices and
* updates. Out-of-bounds accesses on ``a`` are undefined and typically result
* in unintended or invalid memory writes.
*/
array scatter(
const array& a,
const std::vector<array>& indices,