Swap -inf for finite_minimum value (#2029)

This commit is contained in:
Angelos Katharopoulos 2025-03-30 21:55:04 -07:00 committed by GitHub
parent 90823d2938
commit ec2854b13a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 4 deletions

View File

@ -229,7 +229,7 @@ template <
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
max_score[i] = Limits<AccumType>::finite_min;
}
int kb_lim = params->NK;
@ -273,7 +273,7 @@ template <
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@ -294,7 +294,7 @@ template <
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
@ -317,7 +317,7 @@ template <
if (has_mask) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
constexpr auto neg_inf = Limits<selem_t>::finite_min;
constexpr bool is_bool = is_same_v<MaskType, bool>;
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;

View File

@ -572,6 +572,34 @@ class TestSDPA(mlx_tests.MLXTestCase):
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
def test_sdpa_nan_bug(self):
N = 128
q_shape = (1, 1, N, 128)
kv_shape = (1, 1, N, 128)
q = mx.random.uniform(shape=q_shape)
k = mx.random.uniform(shape=kv_shape)
v = mx.random.uniform(shape=kv_shape)
# Make boolean window causal mask
linds = rinds = mx.arange(N)
linds = linds[:, None]
rinds = rinds[None]
mask = linds >= rinds
mask = mask & (linds <= rinds + 111)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
self.assertFalse(mx.isnan(out).any().item())
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
# And an additive one
mask = mx.log(mask)
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
self.assertFalse(mx.isnan(out).any().item())
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
if __name__ == "__main__":
unittest.main(failfast=True)