spelling: at least

Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
Josh Soref 2024-01-01 23:03:31 -05:00
parent 27f640e7c9
commit 5097cd7bca
2 changed files with 3 additions and 3 deletions

View File

@ -40,7 +40,7 @@ void all_reduce_dispatch(
// Set grid dimensions
// We make sure each thread has enough to do by making it read in
// atleast n_reads inputs
// at least n_reads inputs
int n_reads = REDUCE_N_READS;
// mod_in_size gives us the groups of n_reads needed to go over the entire

View File

@ -247,7 +247,7 @@ array tri(int n, int m, int k, Dtype type, StreamOrDevice s /* = {} */) {
array tril(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[tril] array must be atleast 2-D");
throw std::invalid_argument("[tril] array must be at least 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k, x.dtype(), s);
return where(mask, x, zeros_like(x, s), s);
@ -255,7 +255,7 @@ array tril(array x, int k, StreamOrDevice s /* = {} */) {
array triu(array x, int k, StreamOrDevice s /* = {} */) {
if (x.ndim() < 2) {
throw std::invalid_argument("[triu] array must be atleast 2-D");
throw std::invalid_argument("[triu] array must be at least 2-D");
}
auto mask = tri(x.shape(-2), x.shape(-1), k - 1, x.dtype(), s);
return where(mask, zeros_like(x, s), x, s);