mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
added tri / tril / triu (#170)
* added tri / tril / triu * fixed tests * ctest tests * tri overload and simplified tests * changes from comment * more tests for m * ensure assert if not 2-D * remove broadcast_to * minor tweaks --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
22
mlx/ops.cpp
22
mlx/ops.cpp
@@ -218,6 +218,28 @@ array identity(int n, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return eye(n, n, 0, dtype, s);
|
||||
}
|
||||
|
||||
array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
|
||||
auto l = expand_dims(arange(n, s), 1, s);
|
||||
auto r = expand_dims(arange(-k, m - k, s), 0, s);
|
||||
return astype(greater_equal(l, r, s), type, s);
|
||||
}
|
||||
|
||||
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() < 2) {
|
||||
throw std::invalid_argument("[tril] array must be atleast 2-D");
|
||||
}
|
||||
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
||||
return where(mask, x, zeros_like(x, s), s);
|
||||
}
|
||||
|
||||
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() < 2) {
|
||||
throw std::invalid_argument("[triu] array must be atleast 2-D");
|
||||
}
|
||||
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
||||
return where(mask, zeros_like(x, s), x, s);
|
||||
}
|
||||
|
||||
array reshape(
|
||||
const array& a,
|
||||
std::vector<int> shape,
|
||||
|
@@ -110,6 +110,14 @@ inline array identity(int n, StreamOrDevice s = {}) {
|
||||
return identity(n, float32, s);
|
||||
}
|
||||
|
||||
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
|
||||
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
|
||||
return tri(n, n, 0, type, s);
|
||||
}
|
||||
|
||||
array tril(array x, int k, StreamOrDevice s = {});
|
||||
array triu(array x, int k, StreamOrDevice s = {});
|
||||
|
||||
/** array manipulation */
|
||||
|
||||
/** Reshape an array to the given shape. */
|
||||
|
Reference in New Issue
Block a user