mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-30 23:38:09 +08:00
Allow boolean mask in sdpa (#1753)
* allow boolean mask in sdpa * more permissive donation in ternary
This commit is contained in:
@@ -164,9 +164,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
||||
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): An additive mask to apply to the query-key
|
||||
scores. The mask can have at most 4 dimensions and must be
|
||||
broadcast-compatible with the shape ``[B, N, T_q, T_kv]``.
|
||||
mask (array, optional): A boolean or additive mask to apply to the
|
||||
query-key scores. The mask can have at most 4 dimensions and must
|
||||
be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
|
||||
additive mask is given its type must promote to the promoted
|
||||
type of ``q``, ``k``, and ``v``.
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
||||
@@ -187,6 +187,17 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||
|
||||
# Test with boolean causal mask
|
||||
indices = mx.arange(8)
|
||||
bool_mask = indices[:, None] >= indices[None]
|
||||
additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min
|
||||
x = mx.random.normal(shape=(1, 2, 8, 32))
|
||||
y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask)
|
||||
y_hat = mx.fast.scaled_dot_product_attention(
|
||||
x, x, x, scale=scale, mask=bool_mask
|
||||
)
|
||||
self.assertTrue(mx.allclose(y, y_hat, atol=atol))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
||||
Reference in New Issue
Block a user