Add the roll op (#1455)

This commit is contained in:
Angelos Katharopoulos
2024-10-07 17:21:42 -07:00
committed by GitHub
parent f374b6ca4d
commit 9b12093739
6 changed files with 231 additions and 0 deletions

View File

@@ -4795,4 +4795,54 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array.
)pbdoc");
m.def(
"roll",
[](const array& a,
const IntOrVec& shift,
const IntOrVec& axis,
StreamOrDevice s) {
return std::visit(
[&](auto sh, auto ax) -> array {
using T = decltype(ax);
using V = decltype(sh);
if constexpr (std::is_same_v<V, std::monostate>) {
throw std::invalid_argument(
"[roll] Expected two arguments but only one was given.");
} else {
if constexpr (std::is_same_v<T, std::monostate>) {
return roll(a, sh, s);
} else {
return roll(a, sh, ax, s);
}
}
},
shift,
axis);
},
nb::arg(),
"shift"_a,
"axis"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def roll(a: array, shift: Union[int, Tuple[int]], axis: Union[None, int, Tuple[int]] = None, /, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Roll array elements along a given axis.
Elements that are rolled beyond the end of the array are introduced at
the beggining and vice-versa.
If the axis is not provided the array is flattened, rolled and then the
shape is restored.
Args:
a (array): Input array
shift (int or tuple(int)): The number of places by which elements
are shifted. If positive the array is rolled to the right, if
negative it is rolled to the left. If an int is provided but the
axis is a tuple then the same value is used for all axes.
axis (int or tuple(int), optional): The axis or axes along which to
roll the elements.
)pbdoc");
}