mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-29 13:55:29 +08:00
full row mask in sdpa consistently gives nan (#2406)
This commit is contained in:
parent
0f5ce173da
commit
e1840853ce
@ -708,7 +708,10 @@ array scaled_dot_product_attention(
|
|||||||
}
|
}
|
||||||
if (mask.dtype() == bool_) {
|
if (mask.dtype() == bool_) {
|
||||||
scores = where(
|
scores = where(
|
||||||
mask, scores, array(finfo(scores.dtype()).min, scores.dtype()));
|
mask,
|
||||||
|
scores,
|
||||||
|
array(-std::numeric_limits<float>::infinity(), scores.dtype()),
|
||||||
|
s);
|
||||||
} else {
|
} else {
|
||||||
scores = add(scores, mask, s);
|
scores = add(scores, mask, s);
|
||||||
}
|
}
|
||||||
|
@ -398,6 +398,18 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||||
|
|
||||||
|
def test_fully_masked(self):
|
||||||
|
Lkv = 8
|
||||||
|
mask = mx.array(False)
|
||||||
|
for D in [4, 128]:
|
||||||
|
for Lq in [1, 8]:
|
||||||
|
q = mx.random.normal(shape=(1, 4, Lq, D))
|
||||||
|
k = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||||
|
v = mx.random.normal(shape=(1, 4, Lkv, D))
|
||||||
|
|
||||||
|
out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1)
|
||||||
|
self.assertTrue(mx.all(mx.isnan(out)))
|
||||||
|
|
||||||
def test_fast_sdpa_few_query(self):
|
def test_fast_sdpa_few_query(self):
|
||||||
D = 64
|
D = 64
|
||||||
L = 43
|
L = 43
|
||||||
|
Loading…
Reference in New Issue
Block a user