added tri / tril / triu (#170)

* added tri / tril / triu

* fixed tests

* ctest tests

* tri overload and simplified tests

* changes from comment

* more tests for m

* ensure assert if not 2-D

* remove broadcast_to

* minor tweaks

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Diogo
2023-12-15 20:30:34 -05:00
committed by GitHub
parent 2e02acdc83
commit dc2edc762c
9 changed files with 207 additions and 12 deletions

View File

@@ -218,6 +218,28 @@ array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
return eye(n, n, 0, dtype, s);
}
array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
auto l = expand_dims(arange(n, s), 1, s);
auto r = expand_dims(arange(-k, m - k, s), 0, s);
return astype(greater_equal(l, r, s), type, s);
}
array tril(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[tril] array must be atleast 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
return where(mask, x, zeros_like(x, s), s);
}
array triu(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[triu] array must be atleast 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
return where(mask, zeros_like(x, s), x, s);
}
array reshape(
const array& a,
std::vector<int> shape,

View File

@@ -110,6 +110,14 @@ inline array identity(int n, StreamOrDevice s = {}) {
return identity(n, float32, s);
}
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
return tri(n, n, 0, type, s);
}
array tril(array x, int k, StreamOrDevice s = {});
array triu(array x, int k, StreamOrDevice s = {});
/** array manipulation */
/** Reshape an array to the given shape. */