Put along axis + fixe for partition grad (#1430)

* put along axis, fixes for partition grad

* zeros for arg reduce
This commit is contained in:
Awni Hannun
2024-09-23 10:03:38 -07:00
committed by GitHub
parent 2b878e9dd7
commit 195b429d99
9 changed files with 220 additions and 9 deletions

View File

@@ -1463,7 +1463,48 @@ void init_ops(nb::module_& m) {
operation.
Returns:
array: The output array with the specified shape and values.
array: The output array.
)pbdoc");
m.def(
"put_along_axis",
[](const array& a,
const array& indices,
const array& values,
const std::optional<int>& axis,
StreamOrDevice s) {
if (axis.has_value()) {
return put_along_axis(a, indices, values, axis.value(), s);
} else {
return reshape(
put_along_axis(reshape(a, {-1}, s), indices, values, 0, s),
a.shape(),
s);
}
},
nb::arg(),
"indices"_a,
"values"_a,
"axis"_a.none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def put_along_axis(a: array, /, indices: array, values: array, axis: Optional[int] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Put values along an axis at the specified indices.
Args:
a (array): Destination array.
indices (array): Indices array. These should be broadcastable with
the input array excluding the `axis` dimension.
values (array): Values array. These should be broadcastable with
the indices.
axis (int or None): Axis in the destination to put the values to. If
``axis == None`` the destination is flattened prior to the put
operation.
Returns:
array: The output array.
)pbdoc");
m.def(
"full",