mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
27f640e7c9
commit
5097cd7bca
@ -40,7 +40,7 @@ void all_reduce_dispatch(
|
|||||||
// Set grid dimensions
|
// Set grid dimensions
|
||||||
|
|
||||||
// We make sure each thread has enough to do by making it read in
|
// We make sure each thread has enough to do by making it read in
|
||||||
// atleast n_reads inputs
|
// at least n_reads inputs
|
||||||
int n_reads = REDUCE_N_READS;
|
int n_reads = REDUCE_N_READS;
|
||||||
|
|
||||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||||
|
@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
if (x.ndim() < 2) {
|
if (x.ndim() < 2) {
|
||||||
throw std::invalid_argument("[tril] array must be atleast 2-D");
|
throw std::invalid_argument("[tril] array must be at least 2-D");
|
||||||
}
|
}
|
||||||
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
||||||
return where(mask, x, zeros_like(x, s), s);
|
return where(mask, x, zeros_like(x, s), s);
|
||||||
@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
||||||
if (x.ndim() < 2) {
|
if (x.ndim() < 2) {
|
||||||
throw std::invalid_argument("[triu] array must be atleast 2-D");
|
throw std::invalid_argument("[triu] array must be at least 2-D");
|
||||||
}
|
}
|
||||||
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
||||||
return where(mask, zeros_like(x, s), x, s);
|
return where(mask, zeros_like(x, s), x, s);
|
||||||
|
Loading…
Reference in New Issue
Block a user