Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -947,6 +947,14 @@ array take_along_axis(
int axis,
StreamOrDevice s = {});
/** Put the values into the array at the given indices along the axis */
array put_along_axis(
const array& a,
const array& indices,
const array& values,
int axis,
StreamOrDevice s = {});
/** Scatter updates to the given indices.
*
* The parameters ``indices`` and ``axes`` determine the locations of ``a``