mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 16:28:10 +08:00
avoid producing NaN in attention (#2608)
This commit is contained in:
@@ -38,7 +38,7 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
|
||||
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||
|
||||
if mask.dtype == mx.bool_:
|
||||
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||
scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
|
||||
else:
|
||||
scores += mask
|
||||
|
||||
@@ -410,18 +410,15 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_fully_masked(self):
|
||||
Lkv = 8
|
||||
masks = [mx.array(False), mx.array(-float("inf"))]
|
||||
for mask in masks:
|
||||
for D in [4, 128]:
|
||||
for Lq in [1, 8]:
|
||||
q = mx.random.normal(shape=(1, 4, Lq, D))
|
||||
k = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||
v = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||
mask = mx.array(False)
|
||||
for D in [128]:
|
||||
for Lq in [1, 8, 32]:
|
||||
q = mx.random.normal(shape=(1, 4, Lq, D))
|
||||
k = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||
v = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, mask=mask, scale=1
|
||||
)
|
||||
self.assertTrue(mx.all(mx.isnan(out)))
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
|
||||
self.assertFalse(mx.any(mx.isnan(out)))
|
||||
|
||||
def test_inf_score(self):
|
||||
Lkv = 8
|
||||
|
Reference in New Issue
Block a user