mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Support fused masking in Attention (#1924)
* Update API to allow mask='causal' in fast::sdpa * Add fallback * Update steel::AttnParams * Fix typo * WIP, basic causal * Update tests * Update benchmarking * Update masking loop limits * Add bool masking and update tests * Update additive mask * Update benchmarks * Update benchmarks * Update tests * Update for bfloat error * Update early exit * Add random seed to tests
This commit is contained in:
@@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
"memory_efficient_threshold"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
|
||||
|
||||
@@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) {
|
||||
k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
|
||||
v (array): Values with shape ``[B, N_kv, T_kv, D]``.
|
||||
scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
|
||||
mask (array, optional): A boolean or additive mask to apply to the
|
||||
query-key scores. The mask can have at most 4 dimensions and must
|
||||
be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
|
||||
additive mask is given its type must promote to the promoted
|
||||
type of ``q``, ``k``, and ``v``.
|
||||
mask (Union[None, str, array], optional): A causal, boolean or additive
|
||||
mask to apply to the query-key scores. The mask can have at most 4
|
||||
dimensions and must be broadcast-compatible with the shape
|
||||
``[B, N, T_q, T_kv]``. If an additive mask is given its type must
|
||||
promote to the promoted type of ``q``, ``k``, and ``v``.
|
||||
Returns:
|
||||
array: The output array.
|
||||
)pbdoc");
|
||||
|
@@ -6,6 +6,91 @@ import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
q_dtype = q.dtype
|
||||
q = q * mx.array(scale, q_dtype)
|
||||
n_q_heads = q.shape[-3]
|
||||
n_kv_heads = k.shape[-3]
|
||||
n_repeats = n_q_heads // n_kv_heads
|
||||
|
||||
B = q.shape[0]
|
||||
L = q.shape[2]
|
||||
kL = k.shape[2]
|
||||
|
||||
if n_repeats > 1:
|
||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||
k = mx.expand_dims(k, 2)
|
||||
v = mx.expand_dims(v, 2)
|
||||
|
||||
scores = q @ mx.swapaxes(k, -1, -2)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
if mask == "causal":
|
||||
q_offset = max(0, kL - L)
|
||||
q_indices = mx.arange(q_offset, q_offset + L)
|
||||
k_indices = mx.arange(kL)
|
||||
mask = q_indices[:, None] >= k_indices[None]
|
||||
|
||||
if n_repeats > 1 and mask.ndim >= 3:
|
||||
if mask.shape[-3] == 1:
|
||||
mask = mx.expand_dims(mask, -3)
|
||||
else:
|
||||
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||
|
||||
if mask.dtype == mx.bool_:
|
||||
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||
else:
|
||||
scores += mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
|
||||
out = scores @ v
|
||||
if n_repeats > 1:
|
||||
out = mx.reshape(out, [B, n_q_heads, L, -1])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||
if transpose:
|
||||
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||
else:
|
||||
return f(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||
np.random.seed(0)
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||
|
||||
scale = 1.0 / math.sqrt(D)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
if mask is not None:
|
||||
if mask == "additive":
|
||||
mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
|
||||
mask = mx.array(mask_np)
|
||||
elif mask == "bool":
|
||||
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||
mask = mx.array(mask_np)
|
||||
|
||||
return q_mx, k_mx, v_mx, scale, mask
|
||||
|
||||
|
||||
# SDPA for MHA (n_heads == n_kv_heads)
|
||||
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
|
||||
p = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||
@@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
|
||||
|
||||
|
||||
class TestSDPA(mlx_tests.MLXTestCase):
|
||||
@property
|
||||
def dtypes(self):
|
||||
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
|
||||
|
||||
def test_sdpa(self):
|
||||
if not mx.metal.is_available():
|
||||
return
|
||||
|
||||
# fmt: off
|
||||
shapes_64 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 64, 128, 64, 32, 32),
|
||||
( 1, 65, 128, 64, 32, 8),
|
||||
( 1, 64, 127, 64, 32, 8),
|
||||
( 1, 65, 127, 64, 32, 8),
|
||||
( 1, 127, 65, 64, 32, 8),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 128, 128, 128, 32, 8),
|
||||
( 1, 64, 128, 128, 32, 8),
|
||||
( 1, 65, 127, 128, 32, 8),
|
||||
( 1, 127, 65, 128, 32, 8),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_128
|
||||
masks = [None, "additive", "bool", "causal"]
|
||||
transposes = (False, True)
|
||||
|
||||
for dtype in self.dtypes:
|
||||
for t in transposes:
|
||||
for mask_str in masks:
|
||||
for B, qL, kL, D, qH, kH in shapes:
|
||||
with self.subTest(
|
||||
B=B,
|
||||
qsl=qL,
|
||||
ksl=kL,
|
||||
head_dim=D,
|
||||
n_q_heads=qH,
|
||||
n_kv_heads=kH,
|
||||
mask=mask_str,
|
||||
transpose=t,
|
||||
dtype=dtype,
|
||||
):
|
||||
|
||||
np.random.seed(0)
|
||||
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||
B, qL, kL, D, qH, kH, mask_str, t, dtype
|
||||
)
|
||||
|
||||
out_ref = do_attention(
|
||||
mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
|
||||
)
|
||||
|
||||
out_fst = do_attention(
|
||||
mx.fast.scaled_dot_product_attention,
|
||||
q_mx,
|
||||
k_mx,
|
||||
v_mx,
|
||||
scale,
|
||||
mask,
|
||||
t,
|
||||
)
|
||||
|
||||
atol = 2e-5 if dtype == "float32" else 3e-4
|
||||
|
||||
self.assertListEqual(
|
||||
list(out_ref.shape), list(out_fst.shape)
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(failfast=True)
|
||||
|
Reference in New Issue
Block a user