mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 02:58:16 +08:00
Swap -inf for finite_minimum value (#2029)
This commit is contained in:

committed by
GitHub

parent
90823d2938
commit
ec2854b13a
@@ -229,7 +229,7 @@ template <
|
||||
// Init to -Inf
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
max_score[i] = Limits<AccumType>::min;
|
||||
max_score[i] = Limits<AccumType>::finite_min;
|
||||
}
|
||||
|
||||
int kb_lim = params->NK;
|
||||
@@ -273,7 +273,7 @@ template <
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
@@ -294,7 +294,7 @@ template <
|
||||
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||
@@ -317,7 +317,7 @@ template <
|
||||
if (has_mask) {
|
||||
using stile_t = decltype(Stile);
|
||||
using selem_t = typename stile_t::elem_type;
|
||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
||||
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||
|
||||
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
||||
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
|
||||
|
Reference in New Issue
Block a user