added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -1410,6 +1410,72 @@ void init_ops(py::module_& m) {
Returns:
array: An identity matrix of size n x n.
)pbdoc");
m.def(
"tri",
[](int n, std::optional<int> m, int k, Dtype dtype, StreamOrDevice s) {
return tri(n, m.value_or(n), k, float32, s);
},
"n"_a,
"m"_a = none,
"k"_a = 0,
"dtype"_a = float32,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tri(n: int, m: int, k: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array
An array with ones at and below the given diagonal and zeros elsewhere.
Args:
n (int): The number of rows in the output.
m (int, optional): The number of cols in the output. Defaults to ``None``.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
dtype (Dtype, optional): Data type of the output array. Defaults to ``float32``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array with its lower triangle filled with ones and zeros elsewhere
)pbdoc");
m.def(
"tril",
&tril,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
tril(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array above the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed above the given diagonal
)pbdoc");
m.def(
"triu",
&triu,
"x"_a,
"k"_a = 0,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
triu(x: array, k: int, *, stream: Union[None, Stream, Device] = None) -> array
Zeros the array below the given diagonal.
Args:
x (array): input array.
k (int, optional): The diagonal of the 2-D array. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns:
array: Array zeroed below the given diagonal
)pbdoc");
m.def(
"allclose",
&allclose,
@@ -2254,7 +2320,7 @@ void init_ops(py::module_& m) {
Args:
arrays (list(array)): A list of arrays to stack.
axis (int, optional): The axis in the result array along which the
input arrays are stacked. Defaults to ``0``.
input arrays are stacked. Defaults to ``0``.
stream (Stream, optional): Stream or device. Defaults to ``None``.
Returns: