mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
spelling: at least
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
27f640e7c9
commit
5097cd7bca
@ -40,7 +40,7 @@ void all_reduce_dispatch(
|
||||
// Set grid dimensions
|
||||
|
||||
// We make sure each thread has enough to do by making it read in
|
||||
// atleast n_reads inputs
|
||||
// at least n_reads inputs
|
||||
int n_reads = REDUCE_N_READS;
|
||||
|
||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||
|
@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() < 2) {
|
||||
throw std::invalid_argument("[tril] array must be atleast 2-D");
|
||||
throw std::invalid_argument("[tril] array must be at least 2-D");
|
||||
}
|
||||
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
|
||||
return where(mask, x, zeros_like(x, s), s);
|
||||
@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array triu(array x, int k, StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() < 2) {
|
||||
throw std::invalid_argument("[triu] array must be atleast 2-D");
|
||||
throw std::invalid_argument("[triu] array must be at least 2-D");
|
||||
}
|
||||
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
|
||||
return where(mask, zeros_like(x, s), x, s);
|
||||
|
Loading…
Reference in New Issue
Block a user