Support fused masking in Attention (#1924)

* Update API to allow mask='causal' in fast::sdpa
* Add fallback
* Update steel::AttnParams
* Fix typo
* WIP, basic causal
* Update tests
* Update benchmarking
* Update masking loop limits
* Add bool masking and update tests
* Update additive mask
* Update benchmarks
* Update benchmarks
* Update tests
* Update for bfloat error
* Update early exit
* Add random seed to tests
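The user-facing change is that mx.fast.scaled_dot_product_attention now accepts mask="causal" and boolean masks in addition to additive array masks, with a reference fallback for unsupported cases. A minimal usage sketch of the three call patterns exercised by the updated benchmark below; the shapes and random inputs are illustrative, not taken from the diff:

import mlx.core as mx

B, n_heads, L, D = 1, 8, 128, 64
q = mx.random.normal((B, n_heads, L, D))
k = mx.random.normal((B, n_heads, L, D))
v = mx.random.normal((B, n_heads, L, D))
scale = D**-0.5

# Unmasked attention (unchanged behaviour)
o_none = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

# New: fused causal masking, no [L, L] mask array is materialized by the caller
o_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Boolean mask: True keeps a score, False drives it to -inf before the softmax
bool_mask = mx.random.uniform(shape=(B, n_heads, L, L)) < 0.5
o_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)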
@@ -28,11 +28,34 @@ def bench(f, *args):
     return (e - s) * 1e-9
 
 
-def mlx_sdpa_fused_inner(q, k, v, scale):
-    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
+def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
+    np_dtype = getattr(np, dtype)
+
+    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
+    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
+
+    scale = 1.0 / math.sqrt(D)
+
+    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    if mask is not None:
+        if mask == "additive":
+            mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
+            mask = mx.array(mask_np)
+        elif mask == "bool":
+            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
+            mask = mx.array(mask_np)
+
+    return q_mx, k_mx, v_mx, scale, mask
 
 
-def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
     q_dtype = q.dtype
     q = q * mx.array(scale, q_dtype)
     n_q_heads = q.shape[-3]
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
 
     B = q.shape[0]
     L = q.shape[2]
+    kL = k.shape[2]
 
     if n_repeats > 1:
         q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
         v = mx.expand_dims(v, 2)
 
     scores = q @ mx.swapaxes(k, -1, -2)
-    if f32softmax:
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
-    else:
-        scores = mx.softmax(scores, axis=-1)
+
+    if mask is not None:
+
+        if mask == "causal":
+            q_offset = max(0, kL - L)
+            q_indices = mx.arange(q_offset, q_offset + L)
+            k_indices = mx.arange(kL)
+            mask = q_indices[:, None] >= k_indices[None]
+
+        if n_repeats > 1 and mask.ndim >= 3:
+            if mask.shape[-3] == 1:
+                mask = mx.expand_dims(mask, -3)
+            else:
+                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
+
+        if mask.dtype == mx.bool_:
+            scores = mx.where(mask, scores, -np.float32(np.inf))
+        else:
+            scores += mask
+
+    scores = mx.softmax(scores, axis=-1, precise=True)
 
     out = scores @ v
     if n_repeats > 1:
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
     return out
 
 
-def mlx_spda_unfused(q, k, v, scale, transpose):
-    q_out = q
-    if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
-
-    for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-
-    mx.eval(q_out)
-    return q_out
-
-
-def mlx_spda_fused(q, k, v, scale, transpose):
-    q_out = q
-    if transpose:
-        k = mx.transpose(k, (0, 2, 1, 3))
-        v = mx.transpose(v, (0, 2, 1, 3))
-
-    for i in range(N_iter_func):
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-        q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
-        if transpose:
-            q_out = mx.transpose(q_out, (0, 2, 1, 3))
-
-    mx.eval(q_out)
-    return q_out
-
-
-def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
-    shape_q = (
-        (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
-    )
-    shape_kv = (
-        (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
-    )
-
-    q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
-    k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-    v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
-
-    scale = math.sqrt(1.0 / head_dim)
-
-    q_mx = mx.array(q_np)
-    k_mx = mx.array(k_np)
-    v_mx = mx.array(v_np)
-
-    time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
-    time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
-
-    if transpose:
-        q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
-        k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
-        v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
-
-    o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
-    o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
-
-    atol = 1e-5 if np_dtype == np.float32 else 1e-4
-
-    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
+def mlx_fused_attn(q, k, v, scale, mask):
+    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention(f, q, k, v, scale, mask=None, transpose=False):
+    if transpose:
+        q_t = mx.transpose(q, (0, 2, 1, 3))
+        k_t = mx.transpose(k, (0, 2, 1, 3))
+        v_t = mx.transpose(v, (0, 2, 1, 3))
+        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
+        return mx.transpose(o_t, (0, 2, 1, 3))
+    else:
+        return f(q, k, v, scale=scale, mask=mask)
+
+
+def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
+    q_out = q
+
+    for i in range(N_iter_func):
+        q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
+
+    mx.eval(q_out)
+    return q_out
+
+
+def bench_shape(
+    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
+):
+    q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
+        B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
+    )
+
+    time_mlx_unfused = bench(
+        do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+    time_mlx_fused = bench(
+        do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
+    o_mlx_unfused = do_attention(
+        mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
+    )
+
+    atol = 1e-5 if dtype == "float32" else 2e-4
+
+    if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
         print(
-            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
+            f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
         )
 
     return time_mlx_fused, time_mlx_unfused
@@ -151,39 +173,51 @@ if __name__ == "__main__":
          (  1,   128,   128,       64,   32,    32),
          (  1,   256,   256,       64,   32,    32),
          (  1,   512,   512,       64,   32,    32),
-          (  1,  1024,  1024,       64,   32,    32),
-          (  1,  2048,  2048,       64,   32,    32),
-          (  1,  4096,  4096,       64,   32,    32),
+          (  1,  1024,  1024,       64,   32,     8),
+          (  1,  2048,  2048,       64,   32,     8),
+          (  1,  4096,  4096,       64,   32,     8),
     )
 
     shapes_80 = (
         # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
-          (  1,  1024,  1024,       80,   32,    32),
-          (  1,  2048,  2048,       80,   32,    32),
-          (  1,  4096,  4096,       80,   32,    32),
+          (  1,  1024,  1024,       80,   32,     8),
+          (  1,  2048,  2048,       80,   32,     8),
+          (  1,  4096,  4096,       80,   32,     8),
     )
 
     shapes_128 = (
         # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
-          (  1,  1024,  1024,      128,   32,    32),
-          (  1,  2048,  2048,      128,   32,    32),
-          (  1,  4096,  4096,      128,   32,    32),
+          (  1,  1024,  1024,      128,   32,     8),
+          (  1,  2048,  2048,      128,   32,     8),
+          (  1,  4096,  4096,      128,   32,     8),
     )
     # fmt: on
 
     shapes = shapes_64 + shapes_80 + shapes_128
 
-    print("  B,   qsl,   ksl, hdim, n_qh, n_kvh, tpose,   dtype, t_unfs, t_fuse, diff%")
+    masks = [None, "bool", "causal"]
+
+    print(
+        "  B,   qsl,   ksl, hdim, n_qh, n_kvh, t,   dtype,     mask, t_unfs, t_fuse, diff%"
+    )
 
     for dtype in dtypes:
        for transpose in transposes:
            for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
-                np_dtype = getattr(np, dtype)
-                time_mlx_fused, time_mlx_unfused = bench_shape(
-                    B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
-                )
-                diff = time_mlx_unfused / time_mlx_fused - 1.0
-                t_str = 1 if transpose else 0
-                print(
-                    f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
-                )
+                for mask_in in masks:
+                    time_mlx_fused, time_mlx_unfused = bench_shape(
+                        B,
+                        qsl,
+                        ksl,
+                        head_dim,
+                        n_q_heads,
+                        n_kv_heads,
+                        dtype,
+                        transpose,
+                        mask_in,
+                    )
+                    diff = time_mlx_unfused / time_mlx_fused - 1.0
+                    t_str = 1 if transpose else 0
+                    print(
+                        f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
+                    )
@@ -1,4 +1,4 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.
 
 using namespace mlx::steel;
 
@@ -9,6 +9,9 @@ using namespace mlx::steel;
 constant bool align_Q [[function_constant(200)]];
 constant bool align_K [[function_constant(201)]];
 
+constant bool has_mask [[function_constant(300)]];
+constant bool do_causal [[function_constant(301)]];
+
 template <typename T>
 struct TransformScale {
   T scale;
@@ -69,6 +72,7 @@ template <
     int BD,
     int WM,
     int WN,
+    typename MaskType = float,
     typename AccumType = float>
 [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
     const device T* Q [[buffer(0)]],
@@ -76,6 +80,8 @@ template <
     const device T* V [[buffer(2)]],
     device T* O [[buffer(3)]],
     const constant AttnParams* params [[buffer(4)]],
+    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
+    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -102,6 +108,11 @@ template <
       tidl.y * params->O_strides[1] + // Head
       tidl.x * BQ * params->O_strides[2]; // Seqeunce
 
+  if (has_mask) {
+    mask += tidl.z * mask_params->M_strides[0] + // Batch
+        tidl.y * mask_params->M_strides[1]; // Head
+  }
+
   // Prepare threadgroup memory
   constexpr short padQ = 16 / sizeof(T);
   constexpr short padK = 16 / sizeof(T);
@@ -203,7 +214,7 @@ template <
 
   // Load Q blocks apply scale
   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
+    loader_q.load_safe(short2(BD, params->qL_rem));
   } else {
     loader_q.load_unsafe();
   }
@@ -221,12 +232,19 @@ template <
     max_score[i] = Limits<AccumType>::min;
   }
 
+  int kb_lim = params->NK;
+
+  if (do_causal) {
+    int q_max = (tid.x + 1) * BQ + params->qL_off;
+    kb_lim = (q_max + BK - 1) / BK;
+  }
+
   // Loop over KV seq length
-  for (int kb = 0; kb < params->NK; kb++) {
+  for (int kb = 0; kb < kb_lim; kb++) {
     // Load K block and apply scale
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_k.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_k.load_unsafe();
     }
@@ -250,12 +268,11 @@ template <
       tile_matmad(Stile, Qtile, Ktile, Stile);
     }
 
-    // Mask out of length sequence
+    // Mask out length sequence
     if (!align_K && kb == (params->NK_aligned)) {
       using stile_t = decltype(Stile);
       using selem_t = typename stile_t::elem_type;
       constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
-      const short lim = params->kL - params->NK_aligned * BK;
 
       STEEL_PRAGMA_UNROLL
       for (short i = 0; i < stile_t::kTileRows; i++) {
@@ -264,7 +281,7 @@ template <
           short col_pos = sn + (j * stile_t::kFragCols);
           STEEL_PRAGMA_UNROLL
           for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
-            if ((col_pos + jj) >= lim) {
+            if ((col_pos + jj) >= params->kL_rem) {
               Stile.frag_at(i, j)[jj] = neg_inf;
             }
           }
@@ -272,11 +289,78 @@ template <
       }
     }
 
+    // Mask out if causal
+    if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) {
+      using stile_t = decltype(Stile);
+      using selem_t = typename stile_t::elem_type;
+      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < stile_t::kTileRows; i++) {
+        const int row_pos =
+            tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < stile_t::kTileCols; j++) {
+          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+          STEEL_PRAGMA_UNROLL
+          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
+            if (row_pos < (col_pos + jj)) {
+              Stile.frag_at(i, j)[jj] = neg_inf;
+            }
+          }
+        }
+      }
+    }
+
+    // Other masking as needed
+    if (has_mask) {
+      using stile_t = decltype(Stile);
+      using selem_t = typename stile_t::elem_type;
+      constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
+
+      constexpr bool is_bool = is_same_v<MaskType, bool>;
+      using melem_t = typename metal::conditional_t<is_bool, bool, selem_t>;
+
+      using MMAFrag_mask_t = BaseMMAFrag<melem_t, kFragSize, kFragSize>;
+      using frag_t = typename MMAFrag_mask_t::frag_type;
+
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < stile_t::kTileRows; i++) {
+        const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows);
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < stile_t::kTileCols; j++) {
+          const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
+
+          frag_t mfrag;
+
+          MMAFrag_mask_t::load_safe(
+              mfrag,
+              mask,
+              int(mask_params->M_strides[2]),
+              Int<1>{},
+              params->qL,
+              params->kL,
+              row_pos,
+              col_pos);
+
+          STEEL_PRAGMA_UNROLL
+          for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) {
+            if constexpr (is_bool) {
+              Stile.frag_at(i, j)[jj] =
+                  mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf;
+            } else {
+              Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]);
+            }
+          }
+        }
+      }
+    }
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Load V blocks
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_v.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_v.load_unsafe();
     }
@@ -367,8 +451,7 @@ template <
   O += (tm + sm) * params->O_strides[2] + sn;
 
   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    auto dst_tile_dims =
-        short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
+    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
 
     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
       return;
@@ -1,4 +1,4 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.
 
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
@@ -6,26 +6,23 @@
 #include "mlx/backend/metal/kernels/steel/attn/attn.h"
 #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
 
-#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
-  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
-  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
-      const device dtype* Q [[buffer(0)]], \
-      const device dtype* K [[buffer(1)]], \
-      const device dtype* V [[buffer(2)]], \
-      device dtype* O [[buffer(3)]],\
-      const constant AttnParams* params [[buffer(4)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]]);
+#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
+  instantiate_kernel(                                                    \
+      "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd            \
+      "_wm" #wm "_wn" #wn "_mask" #mname,                                \
+  attention, dtype, bq, bk, bd, wm, wn, mtype, float)
 
-#define instantiate_attn_shapes_helper(iname, itype) \
-    instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
-    instantiate_attn(iname, itype, 32, 32,  80, 4, 1) \
-    instantiate_attn(iname, itype, 32, 32,  64, 4, 1)
+#define instantiate_attn_shapes_helper(iname, itype, mname, mtype)  \
+    instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
+    instantiate_attn(iname, itype, 32, 32,  80, 4, 1, mname, mtype) \
+    instantiate_attn(iname, itype, 32, 32,  64, 4, 1, mname, mtype)
 
-instantiate_attn_shapes_helper(float16, half);
-instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+#define instantiate_attn_mask_helper(iname, itype) \
+    instantiate_attn_shapes_helper(iname, itype, iname, itype) \
+    instantiate_attn_shapes_helper(iname, itype, bool_, bool)
 
-instantiate_attn_shapes_helper(float32, float);
+instantiate_attn_mask_helper(float16, half);
+instantiate_attn_mask_helper(bfloat16, bfloat16_t);
+
+instantiate_attn_mask_helper(float32, float);
 // clang-format on
@@ -111,7 +111,7 @@ struct BaseMMAFrag<T, 8, 8> {
       for (short j = 0; j < kElemCols; j++) {
         if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
           dst[i * kElemCols + j] =
-              static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
+              static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
         } else {
           dst[i * kElemCols + j] = T(0);
         }
@@ -26,11 +26,19 @@ struct AttnParams {
   int NQ_aligned; ///< Number of full query blocks
   int NK_aligned; ///< Number of full key/value blocks
 
+  int qL_rem; ///< Remainder in last query block
+  int kL_rem; ///< Remainder in last key/value block
+  int qL_off; ///< Offset in query sequence start
+
   int64_t Q_strides[3]; ///< Query  strides (B, H, L, D = 1)
   int64_t K_strides[3]; ///< Key    strides (B, H, L, D = 1)
   int64_t V_strides[3]; ///< Value  strides (B, H, L, D = 1)
   int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
 };
 
+struct AttnMaskParams {
+  int64_t M_strides[3]; ///< Mask  strides (B, H, qL, kL = 1)
+};
+
 } // namespace steel
 } // namespace mlx
@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& k,
     const array& v,
     const float scale,
-    array& o) {
+    array& o,
+    bool do_causal_ = false,
+    const std::optional<array>& mask = std::nullopt) {
   using namespace mlx::steel;
 
   int wm = 4;
@@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal(
 
   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
+  const bool has_mask = !!mask;
+  const bool do_causal = do_causal_;
 
   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
-  };
+      {&has_mask, MTL::DataType::DataTypeBool, 300},
+      {&do_causal, MTL::DataType::DataTypeBool, 301}};
 
   std::ostringstream kname;
   // clang-format off
@@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal(
         << "_bq" << bq
         << "_bk" << bk
         << "_bd" << bd
-        << "_wm" << wm << "_wn" << wn; // clang-format on
+        << "_wm" << wm
+        << "_wn" << wn
+        << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on
 
   std::string base_name = kname.str();
 
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
+        << "_align_K_" << (align_K ? 't' : 'n')
+        << "_has_mask_" << (has_mask ? 't' : 'n')
+        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
 
   std::string hash_name = kname.str();
 
@@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal(
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,
 
+      /* int qL_rem = */ (qL - NQ_aligned * bq),
+      /* int kL_rem = */ (kL - NK_aligned * bk),
+      /* int qL_off = */ (kL - qL),
+
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal(
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);
 
+  if (mask) {
+    auto m = *mask;
+    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
+        m.strides(0), m.strides(1), m.strides(2)}};
+
+    compute_encoder.set_bytes(mask_params, 5);
+    compute_encoder.set_input_array(m, 6);
+  }
+
   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);
 
@@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(3) == 1;
+    return arr.strides(-1) == 1;
   };
 
   // We are in vector mode ie single query
@@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu(
         {str_oB, str_oH, str_oL, str_oD},
         flags);
 
-    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
+    auto mask = inputs.size() > 3
+        ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
+        : std::nullopt;
+
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
   }
 
   d.add_temporaries(std::move(copies), s.index);
mlx/fast.cpp (83 changed lines)
@@ -567,7 +567,7 @@ array scaled_dot_product_attention(
     const array& keys,
     const array& values,
     const float scale,
-    const std::optional<array>& mask,
+    const std::variant<std::monostate, std::string, array>& mask /* = {}*/,
     const std::optional<int> memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
@@ -578,10 +578,29 @@ array scaled_dot_product_attention(
       throw std::invalid_argument(msg.str());
     }
   }
-  if (mask && (*mask).ndim() > 4) {
+
+  bool do_causal = false;
+  bool has_mask = !std::holds_alternative<std::monostate>(mask);
+  bool has_str_mask = has_mask && std::holds_alternative<std::string>(mask);
+  bool has_arr_mask = has_mask && std::holds_alternative<array>(mask);
+  bool has_bool_mask = false;
+
+  if (has_str_mask) {
+    if (std::get<std::string>(mask) != "causal") {
+      std::ostringstream msg;
+      msg << "[scaled_dot_product_attention] invalid mask option '"
+          << std::get<std::string>(mask) << "'. Must be 'causal', or an array.";
+      throw std::invalid_argument(msg.str());
+    } else {
+      do_causal = true;
+    }
+  }
+
+  if (has_arr_mask && (std::get<array>(mask)).ndim() > 4) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] the mask with shape "
-        << (*mask).shape() << " expected to have at most rank 4";
+        << (std::get<array>(mask)).shape()
+        << " expected to have at most rank 4";
     throw std::invalid_argument(msg.str());
   }
 
@@ -631,9 +650,11 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  if (mask) {
+  if (has_arr_mask) {
     // Check type
-    if (promote_types(mask->dtype(), final_type) != final_type) {
+    auto mask_arr = std::get<array>(mask);
+    has_bool_mask = mask_arr.dtype() == bool_;
+    if (promote_types(mask_arr.dtype(), final_type) != final_type) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Mask type must promote to output type. "
           << final_type << ".";
@@ -642,9 +663,10 @@ array scaled_dot_product_attention(
     // Check shape
     auto mask_shape = queries.shape();
     mask_shape.back() = keys.shape(-2);
-    if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) {
+    if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) {
       std::ostringstream msg;
-      msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape()
+      msg << "[scaled_dot_product_attention] Mask with shape "
+          << mask_arr.shape()
           << " does not broadcast to implicit scores with shape " << mask_shape
           << ".";
       throw std::invalid_argument(msg.str());
@@ -662,7 +684,7 @@ array scaled_dot_product_attention(
     threshold = std::max(1, memory_efficient_threshold.value());
   }
 
-  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s](
+  auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s](
                       const std::vector<array>& inputs) {
     auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
     int n_repeats = n_q_heads / n_kv_heads;
@@ -676,9 +698,21 @@ array scaled_dot_product_attention(
       v = expand_dims(v, 2, s);
     }
     auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
-    if (inputs.size() > 3) {
+    if (inputs.size() > 3 || do_causal) {
       // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv]
-      auto mask = inputs[3];
+      auto mask = inputs.back();
+
+      if (do_causal) {
+        int kL = k.shape(-2);
+        int qL = q.shape(-2);
+        int q_off = (kL - qL) < 0 ? 0 : (kL - qL);
+        auto q_idx = arange(q_off, q_off + qL, s);
+        auto k_idx = arange(0, kL, s);
+        q_idx = expand_dims(q_idx, 1, s);
+        k_idx = expand_dims(k_idx, 0, s);
+        mask = greater_equal(q_idx, k_idx, s);
+      }
+
       if (n_repeats > 1 && mask.ndim() >= 3) {
         if (mask.shape(-3) == 1) {
           mask = expand_dims(mask, -3, s);
@@ -702,9 +736,10 @@ array scaled_dot_product_attention(
   };
 
   auto stream = to_stream(s);
-  const size_t value_head_dim = v.shape(-1);
-  const size_t query_head_dim = q.shape(-1);
-  const size_t query_sequence_length = q.shape(2);
+  const int value_head_dim = v.shape(-1);
+  const int query_head_dim = q.shape(-1);
+  const int query_sequence_length = q.shape(2);
+  const int key_sequence_length = k.shape(2);
 
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
@@ -712,27 +747,33 @@ array scaled_dot_product_attention(
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
-  const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
-      sdpa_full_supported_head_dim && stream.device == Device::gpu;
+  const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
+  const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
+      (query_sequence_length <= key_sequence_length && do_causal);
+
+  const bool supports_sdpa_full = query_sequence_length >= threshold &&
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
+      stream.device == Device::gpu;
 
   const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
-      (query_sequence_length <= k.shape(-2)) &&
-      (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim &&
+      (query_sequence_length <= key_sequence_length) &&
+      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
   const bool implementation_supports_use_case =
       supports_sdpa_full || supports_sdpa_vector;
 
   std::vector<array> inputs = {q, k, v};
-  if (mask) {
-    inputs.push_back(*mask);
+  if (has_arr_mask) {
+    inputs.push_back(std::get<array>(mask));
   }
   if (implementation_supports_use_case) {
     auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
-        std::make_shared<ScaledDotProductAttention>(stream, fallback, scale),
+        std::make_shared<ScaledDotProductAttention>(
+            stream, fallback, scale, do_causal),
         std::move(inputs));
   }
   return fallback(inputs)[0];
@@ -741,7 +782,7 @@ array scaled_dot_product_attention(
 bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
-  return scale_ == a_other.scale_;
+  return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_;
 }
 
 array pack_and_quantize(
@@ -3,6 +3,7 @@
 #pragma once
 
 #include <optional>
+#include <variant>
 
 #include "mlx/utils.h"
 
@@ -47,7 +48,7 @@ array scaled_dot_product_attention(
     const array& keys,
     const array& values,
     const float scale,
-    const std::optional<array>& mask = std::nullopt,
+    const std::variant<std::monostate, std::string, array>& mask = {},
     const std::optional<int> memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});
 
@@ -206,8 +206,9 @@ class ScaledDotProductAttention : public Custom {
   explicit ScaledDotProductAttention(
       Stream stream,
       std::function<std::vector<array>(std::vector<array>)> fallback,
-      const float scale)
-      : Custom(stream, fallback), scale_(scale) {}
+      const float scale,
+      const bool do_causal)
+      : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override {
@@ -225,12 +226,13 @@ class ScaledDotProductAttention : public Custom {
   DEFINE_PRINT(ScaledDotProductAttention);
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
-    return std::make_pair(nullptr, scale_);
+    return std::make_tuple(nullptr, scale_, do_causal_);
   }
 
  private:
   std::function<std::vector<array>(std::vector<array>)> fallback_;
   float scale_;
+  bool do_causal_;
 };
 
 class AffineQuantize : public Custom {
@@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) {
       "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float,  mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
+          "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float,  mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``.
 
@@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) {
             k (array): Keys with shape ``[B, N_kv, T_kv, D]``.
             v (array): Values with shape ``[B, N_kv, T_kv, D]``.
             scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``)
-            mask (array, optional): A boolean or additive mask to apply to the
-               query-key scores. The mask can have at most 4 dimensions and must
-               be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an
-               additive mask is given its type must promote to the promoted
-               type of ``q``, ``k``, and ``v``.
+            mask (Union[None, str, array], optional): A causal, boolean or additive
+               mask to apply to the query-key scores. The mask can have at most 4
+               dimensions and must be broadcast-compatible with the shape
+               ``[B, N, T_q, T_kv]``. If an additive mask is given its type must
+               promote to the promoted type of ``q``, ``k``, and ``v``.
         Returns:
             array: The output array.
       )pbdoc");
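To make the documented options concrete, here is a small usage sketch covering the three mask forms; the shapes and values are illustrative only and are not taken from this patch:

import mlx.core as mx

B, N, T_q, T_kv, D = 1, 4, 8, 8, 64  # illustrative sizes
scale = D ** -0.5

q = mx.random.normal((B, N, T_q, D))
k = mx.random.normal((B, N, T_kv, D))
v = mx.random.normal((B, N, T_kv, D))

# String mask: fused causal masking.
out_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Boolean mask: True means "attend"; broadcast-compatible with [B, N, T_q, T_kv].
bool_mask = mx.arange(T_q)[:, None] >= mx.arange(T_kv)[None]
out_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

# Additive mask: added to the query-key scores; its dtype must promote with q, k, v.
add_mask = mx.where(bool_mask, mx.array(0.0), mx.array(-float("inf")))
out_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)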
@@ -6,6 +6,91 @@ import mlx_tests
 import numpy as np
 
 
+def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
+    q_dtype = q.dtype
+    q = q * mx.array(scale, q_dtype)
+    n_q_heads = q.shape[-3]
+    n_kv_heads = k.shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    B = q.shape[0]
+    L = q.shape[2]
+    kL = k.shape[2]
+
+    if n_repeats > 1:
+        q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
+        k = mx.expand_dims(k, 2)
+        v = mx.expand_dims(v, 2)
+
+    scores = q @ mx.swapaxes(k, -1, -2)
+
+    if mask is not None:
+
+        if mask == "causal":
+            q_offset = max(0, kL - L)
+            q_indices = mx.arange(q_offset, q_offset + L)
+            k_indices = mx.arange(kL)
+            mask = q_indices[:, None] >= k_indices[None]
+
+        if n_repeats > 1 and mask.ndim >= 3:
+            if mask.shape[-3] == 1:
+                mask = mx.expand_dims(mask, -3)
+            else:
+                mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
+
+        if mask.dtype == mx.bool_:
+            scores = mx.where(mask, scores, -np.float32(np.inf))
+        else:
+            scores += mask
+
+    scores = mx.softmax(scores, axis=-1, precise=True)
+
+    out = scores @ v
+    if n_repeats > 1:
+        out = mx.reshape(out, [B, n_q_heads, L, -1])
+
+    return out
+
+
+def do_attention(f, q, k, v, scale, mask=None, transpose=False):
+    if transpose:
+        q_t = mx.transpose(q, (0, 2, 1, 3))
+        k_t = mx.transpose(k, (0, 2, 1, 3))
+        v_t = mx.transpose(v, (0, 2, 1, 3))
+        o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
+        return mx.transpose(o_t, (0, 2, 1, 3))
+    else:
+        return f(q, k, v, scale=scale, mask=mask)
+
+
+def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
+    np.random.seed(0)
+    np_dtype = getattr(np, dtype)
+
+    shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
+    shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
+
+    scale = 1.0 / math.sqrt(D)
+
+    q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
+    k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+    v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
+
+    q_mx = mx.array(q_np)
+    k_mx = mx.array(k_np)
+    v_mx = mx.array(v_np)
+
+    if mask is not None:
+        if mask == "additive":
+            mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype)
+            mask = mx.array(mask_np)
+        elif mask == "bool":
+            mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
+            mask = mx.array(mask_np)
+
+    return q_mx, k_mx, v_mx, scale, mask
+
+
 # SDPA for MHA (n_heads == n_kv_heads)
 def mlx_primitives_sdpa(q, k, v, scale, mask=None):
     p = (q * scale) @ k.transpose(0, 1, 3, 2)
@@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase):
             self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4))
 
 
+class TestSDPA(mlx_tests.MLXTestCase):
+    @property
+    def dtypes(self):
+        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
+
+    def test_sdpa(self):
+        if not mx.metal.is_available():
+            return
+
+        # fmt: off
+        shapes_64 = (
+            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
+            (  1,   128,   128,       64,   32,    32),
+            (  1,    64,   128,       64,   32,    32),
+            (  1,    65,   128,       64,   32,     8),
+            (  1,    64,   127,       64,   32,     8),
+            (  1,    65,   127,       64,   32,     8),
+            (  1,   127,    65,       64,   32,     8),
+        )
+
+        shapes_128 = (
+            # (  B,   qsl,   ksl, head_dim, n_qh, n_kvh)
+            (  1,   128,   128,      128,   32,     8),
+            (  1,    64,   128,      128,   32,     8),
+            (  1,    65,   127,      128,   32,     8),
+            (  1,   127,    65,      128,   32,     8),
+        )
+        # fmt: on
+
+        shapes = shapes_64 + shapes_128
+        masks = [None, "additive", "bool", "causal"]
+        transposes = (False, True)
+
+        for dtype in self.dtypes:
+            for t in transposes:
+                for mask_str in masks:
+                    for B, qL, kL, D, qH, kH in shapes:
+                        with self.subTest(
+                            B=B,
+                            qsl=qL,
+                            ksl=kL,
+                            head_dim=D,
+                            n_q_heads=qH,
+                            n_kv_heads=kH,
+                            mask=mask_str,
+                            transpose=t,
+                            dtype=dtype,
+                        ):
+
+                            np.random.seed(0)
+                            q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
+                                B, qL, kL, D, qH, kH, mask_str, t, dtype
+                            )
+
+                            out_ref = do_attention(
+                                mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t
+                            )
+
+                            out_fst = do_attention(
+                                mx.fast.scaled_dot_product_attention,
+                                q_mx,
+                                k_mx,
+                                v_mx,
+                                scale,
+                                mask,
+                                t,
+                            )
+
+                            atol = 2e-5 if dtype == "float32" else 3e-4
+
+                            self.assertListEqual(
+                                list(out_ref.shape), list(out_fst.shape)
+                            )
+
+                            self.assertTrue(
+                                mx.allclose(out_fst, out_ref, atol=atol, rtol=atol)
+                            )
+
+
 if __name__ == "__main__":
     unittest.main(failfast=True)