mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Swap -inf for finite_minimum value (#2029)
This commit is contained in:
parent
90823d2938
commit
ec2854b13a
@ -229,7 +229,7 @@ template <
|
|||||||
// Init to -Inf
|
// Init to -Inf
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
max_score[i] = Limits<AccumType>::min;
|
max_score[i] = Limits<AccumType>::finite_min;
|
||||||
}
|
}
|
||||||
|
|
||||||
int kb_lim = params->NK;
|
int kb_lim = params->NK;
|
||||||
@ -273,7 +273,7 @@ template <
|
|||||||
if (!align_K && kb == (params->NK_aligned)) {
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
using stile_t = decltype(Stile);
|
using stile_t = decltype(Stile);
|
||||||
using selem_t = typename stile_t::elem_type;
|
using selem_t = typename stile_t::elem_type;
|
||||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
@ -294,7 +294,7 @@ template <
|
|||||||
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
||||||
using stile_t = decltype(Stile);
|
using stile_t = decltype(Stile);
|
||||||
using selem_t = typename stile_t::elem_type;
|
using selem_t = typename stile_t::elem_type;
|
||||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < stile_t::kTileRows; i++) {
|
for (short i = 0; i < stile_t::kTileRows; i++) {
|
||||||
@ -317,7 +317,7 @@ template <
|
|||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
using stile_t = decltype(Stile);
|
using stile_t = decltype(Stile);
|
||||||
using selem_t = typename stile_t::elem_type;
|
using selem_t = typename stile_t::elem_type;
|
||||||
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
|
constexpr auto neg_inf = Limits<selem_t>::finite_min;
|
||||||
|
|
||||||
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
||||||
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
|
using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
|
||||||
|
@ -572,6 +572,34 @@ class TestSDPA(mlx_tests.MLXTestCase):
|
|||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
def test_sdpa_nan_bug(self):
|
||||||
|
N = 128
|
||||||
|
q_shape = (1, 1, N, 128)
|
||||||
|
kv_shape = (1, 1, N, 128)
|
||||||
|
q = mx.random.uniform(shape=q_shape)
|
||||||
|
k = mx.random.uniform(shape=kv_shape)
|
||||||
|
v = mx.random.uniform(shape=kv_shape)
|
||||||
|
|
||||||
|
# Make boolean window causal mask
|
||||||
|
linds = rinds = mx.arange(N)
|
||||||
|
linds = linds[:, None]
|
||||||
|
rinds = rinds[None]
|
||||||
|
mask = linds >= rinds
|
||||||
|
mask = mask & (linds <= rinds + 111)
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||||
|
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
|
||||||
|
self.assertFalse(mx.isnan(out).any().item())
|
||||||
|
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
|
||||||
|
|
||||||
|
# And an additive one
|
||||||
|
mask = mx.log(mask)
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
|
||||||
|
expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
|
||||||
|
self.assertFalse(mx.isnan(out).any().item())
|
||||||
|
self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(failfast=True)
|
unittest.main(failfast=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user