Add move and swap axis, and vmap for slice, concat, and gather (#158)

* add move and swap axis, and vmap for slice, concat, and gather
This commit is contained in:
Awni Hannun
2023-12-14 12:59:12 -08:00
committed by GitHub
parent f55908bc48
commit e5851e52b1
10 changed files with 399 additions and 7 deletions

View File

@@ -1591,6 +1591,50 @@ void init_ops(py::module_& m) {
Returns:
array: The ceil of ``a``.
)pbdoc");
m.def(
"moveaxis",
&moveaxis,
"a"_a,
py::pos_only(),
"source"_a,
"destiantion"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
moveaxis(a: array, /, source: int, destination: int, *, stream: Union[None, Stream, Device] = None) -> array
Move an axis to a new position.
Args:
a (array): Input array.
source (int): Specifies the source axis.
destination (int): Specifies the destination axis.
Returns:
array: The array with the axis moved.
)pbdoc");
m.def(
"swapaxes",
&swapaxes,
"a"_a,
py::pos_only(),
"axis1"_a,
"axis2"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
swapaxes(a: array, /, axis1 : int, axis2: int, *, stream: Union[None, Stream, Device] = None) -> array
Swap two axes of an array.
Args:
a (array): Input array.
axis1 (int): Specifies the first axis.
axis2 (int): Specifies the second axis.
Returns:
array: The array with swapped axes.
)pbdoc");
m.def(
"transpose",
[](const array& a,