"""Tests comparing ``mx.fast.scaled_dot_product_attention`` to reference code.

The reference implementations in this module are written with basic MLX
operations; every test builds inputs, runs both the reference and the fast
kernel, and checks that the outputs agree within a tolerance.
"""

import math
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
    """Reference attention that also handles grouped-query head layouts.

    Args:
        q: Queries, indexed as ``(B, n_q_heads, L, D)``.
        k: Keys, indexed as ``(B, n_kv_heads, kL, D)``.
        v: Values, with the same leading layout as ``k``.
        scale: Factor multiplied into ``q`` before the score matmul.
        mask: ``None``, the string ``"causal"``, a boolean array (``False``
            positions are replaced by ``-inf``) or any other array (added to
            the scores).

    Returns:
        The attention output reshaped back to ``(B, n_q_heads, L, -1)``.
    """
    q_dtype = q.dtype
    q = q * mx.array(scale, q_dtype)
    n_q_heads = q.shape[-3]
    n_kv_heads = k.shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    B = q.shape[0]
    L = q.shape[2]
    kL = k.shape[2]

    if n_repeats > 1:
        # Group the query heads that share a kv head and give k / v a
        # singleton axis so the matmuls broadcast over that group.
        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
        k = mx.expand_dims(k, 2)
        v = mx.expand_dims(v, 2)

    scores = q @ mx.swapaxes(k, -1, -2)

    if mask is not None:
        # Only compare against the sentinel when the mask really is a string;
        # array masks must never go through a string comparison.
        if isinstance(mask, str) and mask == "causal":
            # Queries are aligned to the end of the key sequence.
            q_offset = max(0, kL - L)
            q_indices = mx.arange(q_offset, q_offset + L)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]

        if n_repeats > 1 and mask.ndim >= 3:
            # Make the mask's head axis line up with the grouped scores.
            if mask.shape[-3] == 1:
                mask = mx.expand_dims(mask, -3)
            else:
                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))

        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, -np.float32(np.inf))
        else:
            scores += mask

    scores = mx.softmax(scores, axis=-1, precise=True)

    out = scores @ v
    if n_repeats > 1:
        out = mx.reshape(out, [B, n_q_heads, L, -1])

    return out


def do_attention(f, q, k, v, scale, mask=None, transpose=False):
    """Call attention function ``f``, optionally on axis-swapped inputs.

    When ``transpose`` is true, axes 1 and 2 of ``q``, ``k`` and ``v`` are
    swapped before calling ``f`` and the result is swapped back; otherwise
    ``f`` is called on the inputs as given.
    """
    if transpose:
        q_t = mx.transpose(q, (0, 2, 1, 3))
        k_t = mx.transpose(k, (0, 2, 1, 3))
        v_t = mx.transpose(v, (0, 2, 1, 3))
        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
        return mx.transpose(o_t, (0, 2, 1, 3))
    else:
        return f(q, k, v, scale=scale, mask=mask)


def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
    """Build seeded random attention inputs for one test configuration.

    Args:
        B: Batch size.
        qL: Query sequence length.
        kL: Key / value sequence length.
        D: Head dimension; the returned scale is ``1 / sqrt(D)``.
        qH: Number of query heads.
        kH: Number of key / value heads.
        mask: ``None``, ``"additive"``, ``"bool"`` or any other value.
            ``"additive"`` and ``"bool"`` are replaced by random arrays of
            shape ``(B, qH, qL, kL)``; anything else is returned unchanged.
        transpose: If true, q / k / v are created with the sequence axis
            before the head axis.
        dtype: Name of a NumPy dtype, looked up with ``getattr(np, dtype)``.

    Returns:
        Tuple ``(q, k, v, scale, mask)``.
    """
    # Seeding here makes every call produce the same data for a given shape.
    np.random.seed(0)
    np_dtype = getattr(np, dtype)

    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)

    scale = 1.0 / math.sqrt(D)

    q_np = np.random.normal(0.0, 0.5, shape_q).astype(np_dtype)
    k_np = np.random.normal(0.0, 0.5, shape_kv).astype(np_dtype)
    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)

    q_mx = mx.array(q_np)
    k_mx = mx.array(k_np)
    v_mx = mx.array(v_np)

    if mask is not None:
        if mask == "additive":
            mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
            mask = mx.array(mask_np)
        elif mask == "bool":
            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
            mask = mx.array(mask_np)

    return q_mx, k_mx, v_mx, scale, mask


# SDPA for MHA (n_heads == n_kv_heads)
def mlx_primitives_sdpa(q, k, v, scale, mask=None):
    """Reference attention built from plain MLX ops.

    ``mask`` may be ``None``, the string ``"causal"``, a boolean array
    (``False`` positions are replaced by the smallest float32 value) or any
    other array (added to the scores). The softmax is computed in float32 and
    cast back to the dtype of the scores.
    """
    p = (q * scale) @ k.transpose(0, 1, 3, 2)
    if mask is not None:
        # Only compare against the sentinel when the mask really is a string;
        # array masks must never go through a string comparison.
        if isinstance(mask, str) and mask == "causal":
            # Queries are aligned to the end of the key sequence.
            q_offset = max(0, k.shape[2] - q.shape[2])
            q_indices = mx.arange(q_offset, q_offset + q.shape[2])
            k_indices = mx.arange(k.shape[2])
            mask = q_indices[:, None] >= k_indices[None]
            p = mx.where(mask, p, mx.finfo(mx.float32).min)
        elif mask.dtype == mx.bool_:
            p = mx.where(mask, p, mx.finfo(mx.float32).min)
        else:
            p += mask
    scores = mx.softmax(p.astype(mx.float32), axis=-1).astype(p.dtype)
    return scores @ v


# SDPA for GQA (n_heads > n_kv_heads, n_kv_heads > 1, n_heads % n_kv_heads == 0)
def mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=None):
    """Reference attention for grouped-query inputs.

    ``k`` and ``v`` are tiled ``q.shape[1] // k.shape[1]`` times along a new
    axis and reshaped to ``(B, n_heads, L, -1)`` so that they have as many
    heads as ``q``; the result of ``mlx_primitives_sdpa`` on the tiled
    tensors is returned.
    """
    n_repeats = q.shape[1] // k.shape[1]

    # borrowing kv cache tiling from mlx-examples/llms/mistral/mistral.py
    n_heads = q.shape[1]
    B = q.shape[0]
    L = k.shape[2]

    def repeat(a):
        a = mx.concatenate([mx.expand_dims(a, 2)] * n_repeats, axis=2)
        return a.reshape([B, n_heads, L, -1])

    k, v = map(repeat, (k, v))

    return mlx_primitives_sdpa(q, k, v, scale, mask=mask)


class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase):
    """Fast SDPA checks where queries and keys have the same length."""

    def test_fast_sdpa(self):
        """Compare the fast kernel with the reference without any mask."""
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)
        R = 20
        L = R
        Dk = 64
        H = 3
        scale = float(1.0 / np.sqrt(Dk))
        q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32)
        k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)
        v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32)

        q_mlx = mx.array(q_npy)
        k_mlx = mx.array(k_npy)
        v_mlx = mx.array(v_npy)

        reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)

        o_mlx = mx.fast.scaled_dot_product_attention(
            q_mlx, k_mlx, v_mlx, scale=scale, mask=None
        )

        self.assertListEqual(list(reference.shape), list(o_mlx.shape))
        self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))

        dtypes = [np.float32]

        Dk = 64

        # Half precision is only exercised on Apple silicon.
        if self.is_apple_silicon:
            dtypes.append(np.half)

        for SEQUENCE_LENGTH in [63, 129, 400]:
            for DTYPE in dtypes:
                B = 2
                H = 24
                n_kv_heads = H
                q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype(
                    DTYPE
                )
                k_npy = np.random.normal(
                    0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                ).astype(DTYPE)
                v_npy = np.random.normal(
                    0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                ).astype(DTYPE)

                q_mlx = mx.array(q_npy)
                k_mlx = mx.array(k_npy)
                v_mlx = mx.array(v_npy)

                reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
                o_mlx = mx.fast.scaled_dot_product_attention(
                    q_mlx,
                    k_mlx,
                    v_mlx,
                    scale=scale,
                )

                self.assertListEqual(list(reference.shape), list(o_mlx.shape))
                rtol = 1e-3
                atol = 1e-2

                if SEQUENCE_LENGTH > 500:
                    rtol = 1e-2

                if DTYPE == np.half:
                    rtol = 1e-2

                self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))


class TestFastSDPA(mlx_tests.MLXTestCase):
    """Fast SDPA checks with a single query or only a few queries."""

    def test_fast_sdpa(self):
        """Single-query attention over several key lengths, with/without GQA."""
        # Not yet supported:
        # * K pre-transposed in kernel, V pre-transposed in kernel
        np.random.seed(0)
        L = 43
        R = 1
        Dk = 128
        scale = float(1.0 / np.sqrt(128.0))
        q_npy = np.random.normal(0.0, 1.0, (1, 32, R, Dk)).astype(np.float32)
        k_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)
        v_npy = np.random.normal(0.0, 1.0, (1, 32, L, Dk)).astype(np.float32)

        q_mlx = mx.array(q_npy)
        k_mlx = mx.array(k_npy)
        v_mlx = mx.array(v_npy)

        reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale)

        o_mlx = mx.fast.scaled_dot_product_attention(
            q_mlx, k_mlx, v_mlx, scale=scale, mask=None
        )

        self.assertListEqual(list(reference.shape), list(o_mlx.shape))
        self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4))

        B = 1
        H = 32
        dtypes = [np.float32]
        # Half precision is only exercised on Apple silicon.
        if self.is_apple_silicon:
            dtypes.append(np.half)

        for SEQUENCE_LENGTH in [1, 7, 9, 32, 63, 67, 129, 400, 2000]:
            for DO_GQA in [0, 1]:
                for DTYPE in dtypes:
                    n_kv_heads = 8 if DO_GQA else 32
                    q_npy = np.random.normal(0.0, 1.0, (B, H, R, Dk)).astype(DTYPE)
                    k_npy = np.random.normal(
                        0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                    ).astype(DTYPE)
                    v_npy = np.random.normal(
                        0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk)
                    ).astype(DTYPE)

                    q_mlx = mx.array(q_npy)
                    k_mlx = mx.array(k_npy)
                    v_mlx = mx.array(v_npy)

                    reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
                    o_mlx = mx.fast.scaled_dot_product_attention(
                        q_mlx, k_mlx, v_mlx, scale=scale
                    )

                    self.assertListEqual(list(reference.shape), list(o_mlx.shape))
                    rtol = 1e-5
                    atol = 1e-1

                    if SEQUENCE_LENGTH > 500:
                        rtol = 1e-2

                    if DTYPE == np.half:
                        rtol = 1e-2

                    self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol))

        q = mx.random.normal(shape=(1, 32, 1, Dk))
        k = mx.random.normal(shape=(1, 32, 32, Dk))
        v = mx.random.normal(shape=(1, 32, 128, Dk))

        atol = 1e-6
        # The values are passed as a slice of a longer array.
        y = mlx_primitives_sdpa(q, k, v[:, :, :32], scale)
        y_hat = mx.fast.scaled_dot_product_attention(q, k, v[:, :, :32], scale=scale)
        self.assertTrue(mx.allclose(y, y_hat, atol=atol))

        # Test with per-example mask
        q = mx.random.normal(shape=(2, 8, 4, 32))
        k = mx.random.normal(shape=(2, 2, 8, 32))
        v = mx.random.normal(shape=(2, 2, 8, 32))
        mask = 10 * mx.random.normal(shape=(2, 1, 4, 8))
        y = mlx_primitives_sdpa_with_gqa(q, k, v, scale, mask=mask)
        y_hat = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        self.assertTrue(mx.allclose(y, y_hat, atol=atol))

        # Test with boolean causal mask
        indices = mx.arange(8)
        bool_mask = indices[:, None] >= indices[None]
        additive_mask = (~bool_mask).astype(mx.float32) * mx.finfo(mx.float32).min
        x = mx.random.normal(shape=(1, 2, 8, 32))
        y = mlx_primitives_sdpa_with_gqa(x, x, x, scale, mask=additive_mask)
        y_hat = mx.fast.scaled_dot_product_attention(
            x, x, x, scale=scale, mask=bool_mask
        )
        self.assertTrue(mx.allclose(y, y_hat, atol=atol))

    def test_fast_sdpa_vector_kv_transposed_head_seq(self):
        """Single-query attention where k / v are axis-swapped views."""
        D = 64
        Nq = 4
        Nkv = 1
        scale = 1.0
        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))

        lengths = [43, 4096]
        for L in lengths:
            # k / v are created sequence-first and then swapped to head-first.
            k = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
            v = 5e-1 * mx.random.normal(shape=(1, L, Nkv, D))
            k = k.swapaxes(1, 2)
            v = v.swapaxes(1, 2)
            masks = [
                mx.array(True),
                mx.array([True] * (L - 10) + [False] * 10),
                mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
                mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
            ]

            for m in masks:
                ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
                out = mx.fast.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    scale=scale,
                    mask=m,
                )
                self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_fast_sdpa_vector(self):
        """Single-query attention with boolean, additive and causal masks."""
        D = 64
        L = 43
        Nq = 4
        Nkv = 1
        scale = 1.0
        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        # A mask of shape (Nq, 2, L) is expected to be rejected.
        with self.assertRaises(ValueError):
            mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=mx.full((Nq, 2, L), False),
            )

        masks = [
            None,
            mx.array(True),
            mx.array([True] * (L - 10) + [False] * 10),
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
            mx.random.uniform(shape=(Nq, 1, L)),
            mx.random.uniform(shape=(L, 1, Nq)).T,
            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
            "causal",
        ]
        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        # Same checks with a much longer key / value sequence.
        L = 4096
        scale = 1.0
        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        masks = [
            mx.array(True),
            mx.array([True] * (L - 10) + [False] * 10),
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
            mx.random.uniform(shape=(Nq, 1, L)),
            mx.random.uniform(shape=(L, 1, Nq)).T,
            mx.log(mx.random.uniform(shape=(Nq, 1, L)) > 0.2),
            mx.log(mx.random.uniform(shape=(L, 1, Nq)).T > 0.2),
            "causal",
        ]
        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_fast_sdpa_few_query(self):
        """Attention with eight queries per head against long and short keys."""
        D = 64
        L = 43
        Lq = 8
        Nq = 8
        Nkv = 1
        scale = 1.0
        mx.random.seed(0)
        # The queries are created sequence-first and then swapped.
        q = 5e-1 * mx.random.normal(shape=(1, Lq, Nq, D))
        q = q.swapaxes(1, 2)
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        masks = [
            None,
            mx.array(True),
            mx.array([True] * (L - 10) + [False] * 10),
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
            "causal",
        ]
        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        # Same checks with a much longer key / value sequence.
        L = 4096
        scale = 1.0
        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, Lq, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))

        masks = [
            None,
            mx.array(True),
            mx.array([True] * (L - 10) + [False] * 10),
            mx.random.uniform(shape=(Nq, 1, L)) > 0.2,
            mx.random.uniform(shape=(L, 1, Nq)).T > 0.2,
            "causal",
        ]
        for m in masks:
            ref = mlx_primitives_sdpa(q, k, v, scale, mask=m)
            out = mx.fast.scaled_dot_product_attention(
                q,
                k,
                v,
                scale=scale,
                mask=m,
            )
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    @unittest.skip("Different head and value dims is not enabled")
    def test_fast_sdpa_vector_value_dims(self):
        """Single-query attention where the value dim differs from the key dim."""
        D = 192
        V = 128
        Nq = 4
        Nkv = 1
        scale = 1.0
        mx.random.seed(0)

        for L in [43, 128, 237, 8192]:
            q = 5e-1 * mx.random.normal(shape=(1, Nq, 1, D))
            k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
            v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, V))
            ref = mlx_primitives_sdpa(q, k, v, scale)
            out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
            self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_sdpa_vector_batched(self):
        """Batched inputs, including axis-swapped q / k and a swapped mask."""
        D = 64
        q = mx.random.normal(shape=(2, 1, 3, D))
        k = mx.random.normal(shape=(2, 1, 3, D))
        v = mx.random.normal(shape=(2, 1, 3, D))
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
        ref = mlx_ref_attn(q, k, v)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        q = mx.random.normal(shape=(2, 4, 3, D))
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
        ref = mlx_ref_attn(q, k, v)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        q = mx.random.normal(shape=(2, 3, 4, D)).swapaxes(1, 2)
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
        ref = mlx_ref_attn(q, k, v)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        k = mx.random.normal(shape=(2, 3, 1, D)).swapaxes(1, 2)
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
        ref = mlx_ref_attn(q, k, v)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        q = mx.random.normal(shape=(2, 4, 3, D))
        k = mx.random.normal(shape=(2, 3, 2, D)).swapaxes(1, 2)
        v = mx.random.normal(shape=(2, 2, 3, D))
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=None, scale=1.0)
        ref = mlx_ref_attn(q, k, v)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

        q = mx.random.normal(shape=(2, 4, 3, D))
        k = mx.random.normal(shape=(2, 1, 3, D))
        v = mx.random.normal(shape=(2, 1, 3, D))
        mask = 10 * mx.random.normal(shape=(1, 2, 3, 3)).swapaxes(0, 1)
        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
        ref = mlx_ref_attn(q, k, v, mask=mask)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))


class TestSDPA(mlx_tests.MLXTestCase):
    """Fast SDPA checks with multi-query shapes and assorted mask kinds."""

    @property
    def dtypes(self):
        """Dtype names to test; float16 is included only when Metal is available."""
        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]

    def test_sdpa(self):
        """Sweep shapes, mask kinds, layouts and dtypes against the reference."""
        if not mx.metal.is_available():
            # Report the test as skipped instead of silently passing.
            self.skipTest("Metal is not available")

        # fmt: off
        shapes_64 = (
            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
              (  1,   128,   128,       64,   32,    32),
              (  1,    64,   128,       64,   32,    32),
              (  1,    65,   128,       64,   32,     8),
              (  1,    64,   127,       64,   32,     8),
              (  1,    65,   127,       64,   32,     8),
              (  1,   127,    65,       64,   32,     8),
        )

        shapes_128 = (
            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
              (  1,   128,   128,      128,   32,     8),
              (  1,    64,   128,      128,   32,     8),
              (  1,    65,   127,      128,   32,     8),
              (  1,   127,    65,      128,   32,     8),
        )
        # fmt: on

        shapes = shapes_64 + shapes_128
        masks = [None, "additive", "bool", "causal"]
        transposes = (False, True)

        for dtype in self.dtypes:
            for t in transposes:
                for mask_str in masks:
                    for B, qL, kL, D, qH, kH in shapes:
                        with self.subTest(
                            B=B,
                            qsl=qL,
                            ksl=kL,
                            head_dim=D,
                            n_q_heads=qH,
                            n_kv_heads=kH,
                            mask=mask_str,
                            transpose=t,
                            dtype=dtype,
                        ):

                            np.random.seed(0)
                            q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
                                B, qL, kL, D, qH, kH, mask_str, t, dtype
                            )

                            out_ref = do_attention(
                                mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
                            )

                            out_fst = do_attention(
                                mx.fast.scaled_dot_product_attention,
                                q_mx,
                                k_mx,
                                v_mx,
                                scale,
                                mask,
                                t,
                            )

                            atol = 2e-5 if dtype == "float32" else 3e-4

                            self.assertListEqual(
                                list(out_ref.shape), list(out_fst.shape)
                            )

                            # Allowed error is atol plus atol times |reference|.
                            diff = mx.abs(out_fst - out_ref) - atol * mx.abs(out_ref)
                            self.assertLessEqual(mx.max(diff).item(), atol)

    def test_sdpa_broadcast_mask(self):
        """A scalar boolean mask must give the same result as the reference."""
        mask = mx.array(True)
        D = 64
        Nq = 4
        Nkv = 1
        scale = 1.0
        L = 256

        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    # The method name keeps its original spelling so existing test selectors
    # continue to match.
    def test_sdpa_prommote_mask(self):
        """A scalar bfloat16 additive mask must match the reference."""
        mask = mx.array(2.0, mx.bfloat16)
        D = 64
        Nq = 4
        Nkv = 1
        scale = 1.0
        L = 256

        mx.random.seed(0)
        q = 5e-1 * mx.random.normal(shape=(1, Nq, L, D))
        k = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        v = 5e-1 * mx.random.normal(shape=(1, Nkv, L, D))
        ref = mlx_primitives_sdpa(q, k, v, scale, mask=mask)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
        self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))

    def test_sdpa_nan_bug(self):
        """Windowed causal masks must not produce NaNs in the fast kernel."""
        N = 128
        q_shape = (1, 1, N, 128)
        kv_shape = (1, 1, N, 128)
        q = mx.random.uniform(shape=q_shape)
        k = mx.random.uniform(shape=kv_shape)
        v = mx.random.uniform(shape=kv_shape)

        # Make boolean window causal mask
        linds = rinds = mx.arange(N)
        linds = linds[:, None]
        rinds = rinds[None]
        mask = linds >= rinds
        mask = mask & (linds <= rinds + 111)

        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
        self.assertFalse(mx.isnan(out).any().item())
        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)

        # And an additive one
        mask = mx.log(mask)

        out = mx.fast.scaled_dot_product_attention(q, k, v, mask=mask, scale=1.0)
        expected = mlx_ref_attn(q, k, v, mask=mask, scale=1.0)
        self.assertFalse(mx.isnan(out).any().item())
        self.assertLessEqual(mx.abs(out - expected).max().item(), 1e-4)


if __name__ == "__main__":
    unittest.main(failfast=True)