Swap -inf for finite_minimum value (#2029)

This commit is contained in:
Angelos Katharopoulos
2025-03-30 21:55:04 -07:00
committed by GitHub
parent 90823d2938
commit ec2854b13a
2 changed files with 32 additions and 4 deletions

View File

@@ -229,7 +229,7 @@ template <
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
max_score[i] = Limits<AccumType>::finite_min;
}
int kb_lim = params->NK;
@@ -273,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -294,7 +294,7 @@ template <
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -317,7 +317,7 @@ template <
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;