From 0e585b4409b0548c2928fc8186be8eb2d513ade6 Mon Sep 17 00:00:00 2001 From: nicolov Date: Thu, 6 Jun 2024 20:51:25 +0200 Subject: [PATCH] Add docstring for scatter (#1189) * Add docstring for scatter * docs nits --------- Co-authored-by: Awni Hannun --- mlx/ops.h | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/mlx/ops.h b/mlx/ops.h index 934edf619..069400ba8 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -927,7 +927,104 @@ array take_along_axis( int axis, StreamOrDevice s = {}); -/** Scatter updates to given linear indices */ +/** Scatter updates to the given indices. + * + * The parameters ``indices`` and ``axes`` determine the locations of ``a`` + * that are updated with the values in ``updates``. Assuming 1-d ``indices`` + * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which + * the values in ``updates`` will be applied. Note each array in + * ``indices`` is assigned to a corresponding axis and hence ``indices.size() == + * axes.size()``. If an index/axis pair is not provided then indices along that + * axis are assumed to be zero. + * + * Note the rank of ``updates`` must be equal to the sum of the rank of the + * broadcasted ``indices`` and the rank of ``a``. In other words, assuming the + * arrays in ``indices`` have the same shape, ``updates.ndim() == + * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates`` + * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the + * values that will be applied to the given location in ``a``. 
+ * + * For example: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2}); + * auto updates = reshape(arange(1, 3, float32), {1, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2, + * 0)`` position of ``a``. + * + * Adding another element to ``indices`` will scatter into another location of + * ``a``. We also have to add another update for the new index: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2, 0}); + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[3, 4, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * To control the scatter location on an additional axis, add another index + * array to ``indices`` and another axis to ``axes``: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = std::vector{array({2, 0}), array({1, 2})}; + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0, 1}; + * + * auto out = scatter(in, indices, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 3, 4], + * [0, 0, 0, 0], + * [0, 1, 2, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * Items in indices are broadcasted together. This means: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1})}; + * @endcode + * + * is equivalent to: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1, 1})}; + * @endcode + * + * Note, ``scatter`` does not perform bounds checking on the indices and + * updates. 
Out-of-bounds accesses on ``a`` are undefined and typically result + * in unintended or invalid memory writes. + */ array scatter( const array& a, const std::vector& indices,