Adds round op and primitive (#203)

This commit is contained in:
Angelos Katharopoulos
2023-12-18 11:32:48 -08:00
committed by GitHub
parent 477397bc98
commit 4d4af12c6f
17 changed files with 187 additions and 2 deletions

View File

@@ -1148,5 +1148,15 @@ void init_array(py::module_& m) {
"reverse"_a = false,
"inclusive"_a = true,
"stream"_a = none,
"See :func:`cummin`.");
"See :func:`cummin`.")
.def(
"round",
[](const array& a, int decimals, StreamOrDevice s) {
return round(a, decimals, s);
},
py::pos_only(),
"decimals"_a = 0,
py::kw_only(),
"stream"_a = none,
"See :func:`round`.");
}

View File

@@ -2922,4 +2922,33 @@ void init_ops(py::module_& m) {
Returns:
result (array): The output containing elements selected from ``x`` and ``y``.
)pbdoc");
m.def(
"round",
[](const array& a, int decimals, StreamOrDevice s) {
return round(a, decimals, s);
},
"a"_a,
py::pos_only(),
"decimals"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
round(a: array, /, decimals: int = 0, stream: Union[None, Stream, Device] = None) -> array
Round to the given number of decimals.
Bascially performs:
.. code-block:: python
s = 10**decimals
x = round(x * s) / s
Args:
a (array): Input array
decimals (int): Number of decimal places to round to. (default: 0)
Returns:
result (array): An array of the same type as ``a`` rounded to the given number of decimals.
)pbdoc");
}