added atleast *args input support (#710)

* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
This commit is contained in:
Hinrik Snær Guðmundsson
2024-02-26 14:17:59 -05:00
committed by GitHub
parent 3b661b7394
commit 08226ab491
5 changed files with 131 additions and 30 deletions

View File

@@ -3638,62 +3638,69 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_1d(arys[0].cast<array>(), s));
}
return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least one dimension.
Convert all arrays to have at least one dimension.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least one dimension.
array or list(array): An array or list of arrays with at least one dimension.
)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_2d(arys[0].cast<array>(), s));
}
return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least two dimensions.
Convert all arrays to have at least two dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least two dimensions.
array or list(array): An array or list of arrays with at least two dimensions.
)pbdoc");
m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_3d(arys[0].cast<array>(), s));
}
return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]
Convert array to have at least three dimensions.
Convert all arrays to have at least three dimensions.
args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.
Returns:
array: An array with at least three dimensions.
array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc");
}