scatter axis + gather axis primitives (#1813)

* scatter axis + gather axis primitives

* add transforms

* comment
This commit is contained in:
Awni Hannun
2025-01-31 20:48:08 -08:00
committed by GitHub
parent c6fc07f1f4
commit b7c9f1d38f
15 changed files with 1037 additions and 85 deletions

View File

@@ -968,6 +968,14 @@ array put_along_axis(
int axis,
StreamOrDevice s = {});
/** Add the values into the array at the given indices along the axis */
array scatter_add_axis(
const array& a,
const array& indices,
const array& values,
int axis,
StreamOrDevice s = {});
/** Scatter updates to the given indices.
*
* The parameters ``indices`` and ``axes`` determine the locations of ``a``