Allow boolean mask in sdpa (#1753)

* allow boolean mask in sdpa

* more permissive donation in ternary
This commit is contained in:
Awni Hannun
2025-01-06 16:57:07 -08:00
committed by GitHub
parent 25b3a3e541
commit d5ec172c95
4 changed files with 28 additions and 5 deletions

View File

@@ -187,6 +187,17 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
# Test with boolean causal mask
indices = mx.arange(8)
bool_mask = indices[:, None] >= indices[None]
additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min
x = mx.random.normal(shape=(1, 2, 8, 32))
y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask)
y_hat = mx.fast.scaled_dot_product_attention(
x, x, x, scale=scale, mask=bool_mask
)
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
if __name__ == "__main__":
unittest.main(failfast=True)