Support fused masking in Attention (#1924)

* Update API to allow mask='causal' in fast::sdpa

* Add fallback

* Update steel::AttnParams

* Fix typo

* WIP, basic causal

* Update tests

* Update benchmarking

* Update masking loop limits

* Add bool masking and update tests

* Update additive mask

* Update benchmarks

* Update benchmarks

* Update tests

* Update for bfloat error

* Update early exit

* Add random seed to tests

Author: Jagrit Digani
Date:   2025-03-20 11:01:32 -07:00 (committed by GitHub)
Parent: 3c164fca8c
Commit: 9adcd1a650

11 changed files with 504 additions and 148 deletions

@@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) {
"memory_efficient_threshold"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
@@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) {
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1))``)
- mask (array, optional): A boolean or additive mask to apply to the
-   query-key scores. The mask can have at most 4 dimensions and must
-   be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
-   additive mask is given its type must promote to the promoted
-   type of ``q``, ``k``, and ``v``.
+ mask (Union[None, str, array], optional): A causal, boolean or additive
+   mask to apply to the query-key scores. The mask can have at most 4
+   dimensions and must be broadcast-compatible with the shape
+   ``[B, N, T_q, T_kv]``. If an additive mask is given its type must
+   promote to the promoted type of ``q``, ``k``, and ``v``.
Returns:
array: The output array.
)pbdoc");

@@ -6,6 +6,91 @@ import mlx_tests
import numpy as np

def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]
    kL = k.shape[2]

    if n_repeats > 1:
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:
        if mask == "causal":
            q_offset = max(0, kL - L)
            q_indices = mx.arange(q_offset, q_offset + L)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]

        if n_repeats > 1 and mask.ndim >= 3:
            if mask.shape[-3] == 1:
                mask = mx.expand_dims(mask, -3)
            else:
                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))

        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -np.float32(np.inf))
        else:
            scores += mask

    scores = mx.softmax(scores, axis=-1, precise=True)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out
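
# Note on the "causal" string mask handled above: when the query and key lengths
# differ, queries are aligned to the end of the key sequence, so query position i
# may attend to key positions j with j <= i + max(0, kL - L). For example, L = 2
# and kL = 4 give the boolean mask
#   [[True, True, True, False],
#    [True, True, True,  True]]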

def do_attention(f, q, k, v, scale, mask=None, transpose=False):
    if transpose:
        q_t = mx.transpose(q, (0, 2, 1, 3))
        k_t = mx.transpose(k, (0, 2, 1, 3))
        v_t = mx.transpose(v, (0, 2, 1, 3))
        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
        return mx.transpose(o_t, (0, 2, 1, 3))
    else:
        return f(q, k, v, scale=scale, mask=mask)

def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
    np.random.seed(0)
    np_dtype = getattr(np, dtype)

    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)

    scale = 1.0 / math.sqrt(D)

    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    if mask is not None:
        if mask == "additive":
            mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
            mask = mx.array(mask_np)
        elif mask == "bool":
            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
            mask = mx.array(mask_np)

    return q_mx, k_mx, v_mx, scale, mask

# SDPA for MHA (n_heads == n_kv_heads)
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
    p = (q * scale) @ k.transpose(0, 1, 3, 2)
@@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

class TestSDPA(mlx_tests.MLXTestCase):
    @property
    def dtypes(self):
        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]

    def test_sdpa(self):
        if not mx.metal.is_available():
            return

        # fmt: off
        shapes_64 = (
            # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
            (    1,  128,  128,       64,   32,    32),
            (    1,   64,  128,       64,   32,    32),
            (    1,   65,  128,       64,   32,     8),
            (    1,   64,  127,       64,   32,     8),
            (    1,   65,  127,       64,   32,     8),
            (    1,  127,   65,       64,   32,     8),
        )

        shapes_128 = (
            # (  B,  qsl,  ksl, head_dim, n_qh, n_kvh)
            (    1,  128,  128,      128,   32,     8),
            (    1,   64,  128,      128,   32,     8),
            (    1,   65,  127,      128,   32,     8),
            (    1,  127,   65,      128,   32,     8),
        )
        # fmt: on

        shapes = shapes_64 + shapes_128

        masks = [None, "additive", "bool", "causal"]
        transposes = (False, True)

        for dtype in self.dtypes:
            for t in transposes:
                for mask_str in masks:
                    for B, qL, kL, D, qH, kH in shapes:
                        with self.subTest(
                            B=B,
                            qsl=qL,
                            ksl=kL,
                            head_dim=D,
                            n_q_heads=qH,
                            n_kv_heads=kH,
                            mask=mask_str,
                            transpose=t,
                            dtype=dtype,
                        ):
                            np.random.seed(0)
                            q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
                                B, qL, kL, D, qH, kH, mask_str, t, dtype
                            )

                            out_ref = do_attention(
                                mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
                            )
                            out_fst = do_attention(
                                mx.fast.scaled_dot_product_attention,
                                q_mx,
                                k_mx,
                                v_mx,
                                scale,
                                mask,
                                t,
                            )

                            atol = 2e-5 if dtype == "float32" else 3e-4

                            self.assertListEqual(
                                list(out_ref.shape), list(out_fst.shape)
                            )

                            self.assertTrue(
                                mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
                            )

if __name__ == "__main__":
    unittest.main(failfast=True)