From 9adcd1a650b6f97b9d9e6e07e96c09e437184e9a Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 20 Mar 2025 11:01:32 -0700 Subject: [PATCH] Support fused masking in Attention (#1924) * Update API to allow mask='causal' in fast::sdpa * Add fallback * Update steel::AttnParams * Fix typo * WIP, basic causal * Update tests * Update benchmarking * Update masking loop limits * Add bool masking and update tests * Update additive mask * Update benchmarks * Update benchmarks * Update tests * Update for bfloat error * Update early exit * Add random seed to tests --- benchmarks/python/sdpa_bench.py | 194 ++++++++++-------- .../steel/attn/kernels/steel_attention.h | 103 +++++++++- .../steel/attn/kernels/steel_attention.metal | 37 ++-- mlx/backend/metal/kernels/steel/attn/mma.h | 2 +- mlx/backend/metal/kernels/steel/attn/params.h | 8 + .../metal/scaled_dot_product_attention.cpp | 38 +++- mlx/fast.cpp | 83 ++++++-- mlx/fast.h | 3 +- mlx/fast_primitives.h | 8 +- python/src/fast.cpp | 12 +- python/tests/test_fast_sdpa.py | 164 +++++++++++++++ 11 files changed, 504 insertions(+), 148 deletions(-) diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py index 23383475e..5eb789de0 100644 --- a/benchmarks/python/sdpa_bench.py +++ b/benchmarks/python/sdpa_bench.py @@ -28,11 +28,34 @@ def bench(f, *args): return (e - s) * 1e-9 -def mlx_sdpa_fused_inner(q, k, v, scale): - return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None) +def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): + np_dtype = getattr(np, dtype) + + shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D) + shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D) + + scale = 1.0 / math.sqrt(D) + + q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + if mask is not None: + if mask == "additive": + mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype) + mask = mx.array(mask_np) + elif mask == "bool": + mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5 + mask = mx.array(mask_np) + + return q_mx, k_mx, v_mx, scale, mask -def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): +def mlx_ref_attn(q, k, v, scale=1.0, mask=None): q_dtype = q.dtype q = q * mx.array(scale, q_dtype) n_q_heads = q.shape[-3] @@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): B = q.shape[0] L = q.shape[2] + kL = k.shape[2] if n_repeats > 1: q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) @@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): v = mx.expand_dims(v, 2) scores = q @ mx.swapaxes(k, -1, -2) - if f32softmax: - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype) - else: - scores = mx.softmax(scores, axis=-1) + + if mask is not None: + + if mask == "causal": + q_offset = max(0, kL - L) + q_indices = mx.arange(q_offset, q_offset + L) + k_indices = mx.arange(kL) + mask = q_indices[:, None] >= k_indices[None] + + if n_repeats > 1 and mask.ndim >= 3: + if mask.shape[-3] == 1: + mask = mx.expand_dims(mask, -3) + else: + mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) + + if mask.dtype == mx.bool_: + scores = mx.where(mask, scores, -np.float32(np.inf)) + else: + scores += mask + + scores = mx.softmax(scores, axis=-1, precise=True) out = scores @ v if n_repeats > 1: @@ -60,74 
+101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False): return out -def mlx_spda_unfused(q, k, v, scale, transpose): - q_out = q +def mlx_fused_attn(q, k, v, scale, mask): + return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + + +def do_attention(f, q, k, v, scale, mask=None, transpose=False): if transpose: - k = mx.transpose(k, (0, 2, 1, 3)) - v = mx.transpose(v, (0, 2, 1, 3)) + q_t = mx.transpose(q, (0, 2, 1, 3)) + k_t = mx.transpose(k, (0, 2, 1, 3)) + v_t = mx.transpose(v, (0, 2, 1, 3)) + o_t = f(q_t, k_t, v_t, scale=scale, mask=mask) + return mx.transpose(o_t, (0, 2, 1, 3)) + else: + return f(q, k, v, scale=scale, mask=mask) + + +def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False): + q_out = q for i in range(N_iter_func): - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale) - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) + q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose) mx.eval(q_out) return q_out -def mlx_spda_fused(q, k, v, scale, transpose): - q_out = q - if transpose: - k = mx.transpose(k, (0, 2, 1, 3)) - v = mx.transpose(v, (0, 2, 1, 3)) - - for i in range(N_iter_func): - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - q_out = mlx_sdpa_fused_inner(q_out, k, v, scale) - if transpose: - q_out = mx.transpose(q_out, (0, 2, 1, 3)) - - mx.eval(q_out) - return q_out - - -def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True): - shape_q = ( - (B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim) - ) - shape_kv = ( - (B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim) +def bench_shape( + B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None +): + q_mx, k_mx, v_mx, scale, mask = prepare_inputs( + B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype ) - q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype) - k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) - v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype) + time_mlx_unfused = bench( + do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) + time_mlx_fused = bench( + do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) - scale = math.sqrt(1.0 / head_dim) + o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose) + o_mlx_unfused = do_attention( + mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose + ) - q_mx = mx.array(q_np) - k_mx = mx.array(k_np) - v_mx = mx.array(v_np) + atol = 1e-5 if dtype == "float32" else 2e-4 - time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose) - time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose) - - if transpose: - q_mx = mx.transpose(q_mx, (0, 2, 1, 3)) - k_mx = mx.transpose(k_mx, (0, 2, 1, 3)) - v_mx = mx.transpose(v_mx, (0, 2, 1, 3)) - - o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale) - o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True) - - atol = 1e-5 if np_dtype == np.float32 else 1e-4 - - if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol): + if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol): print( - f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = 
{transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" + f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}" ) return time_mlx_fused, time_mlx_unfused @@ -151,39 +173,51 @@ if __name__ == "__main__": ( 1, 128, 128, 64, 32, 32), ( 1, 256, 256, 64, 32, 32), ( 1, 512, 512, 64, 32, 32), - ( 1, 1024, 1024, 64, 32, 32), - ( 1, 2048, 2048, 64, 32, 32), - ( 1, 4096, 4096, 64, 32, 32), + ( 1, 1024, 1024, 64, 32, 8), + ( 1, 2048, 2048, 64, 32, 8), + ( 1, 4096, 4096, 64, 32, 8), ) shapes_80 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) - ( 1, 1024, 1024, 80, 32, 32), - ( 1, 2048, 2048, 80, 32, 32), - ( 1, 4096, 4096, 80, 32, 32), + ( 1, 1024, 1024, 80, 32, 8), + ( 1, 2048, 2048, 80, 32, 8), + ( 1, 4096, 4096, 80, 32, 8), ) shapes_128 = ( # ( B, qsl, ksl, head_dim, n_qh, n_kvh) - ( 1, 1024, 1024, 128, 32, 32), - ( 1, 2048, 2048, 128, 32, 32), - ( 1, 4096, 4096, 128, 32, 32), + ( 1, 1024, 1024, 128, 32, 8), + ( 1, 2048, 2048, 128, 32, 8), + ( 1, 4096, 4096, 128, 32, 8), ) # fmt: on shapes = shapes_64 + shapes_80 + shapes_128 - print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%") + masks = [None, "bool", "causal"] + + print( + " B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%" + ) for dtype in dtypes: for transpose in transposes: for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes: - np_dtype = getattr(np, dtype) - time_mlx_fused, time_mlx_unfused = bench_shape( - B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose - ) - diff = time_mlx_unfused / time_mlx_fused - 1.0 - t_str = 1 if transpose else 0 - print( - f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" - ) + for mask_in in masks: + time_mlx_fused, time_mlx_unfused = bench_shape( + B, + qsl, + ksl, + head_dim, + n_q_heads, + n_kv_heads, + dtype, + transpose, + mask_in, + ) + diff = time_mlx_unfused / time_mlx_fused - 1.0 + t_str = 1 if transpose else 0 + print( + f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%" + ) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index b2e70ef8d..a8469e0ff 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -1,4 +1,4 @@ -// Copyright © 2024 Apple Inc. +// Copyright © 2024-25 Apple Inc. 
using namespace mlx::steel; @@ -9,6 +9,9 @@ using namespace mlx::steel; constant bool align_Q [[function_constant(200)]]; constant bool align_K [[function_constant(201)]]; +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; + template struct TransformScale { T scale; @@ -69,6 +72,7 @@ template < int BD, int WM, int WN, + typename MaskType = float, typename AccumType = float> [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( const device T* Q [[buffer(0)]], @@ -76,6 +80,8 @@ template < const device T* V [[buffer(2)]], device T* O [[buffer(3)]], const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -102,6 +108,11 @@ template < tidl.y * params->O_strides[1] + // Head tidl.x * BQ * params->O_strides[2]; // Seqeunce + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + // Prepare threadgroup memory constexpr short padQ = 16 / sizeof(T); constexpr short padK = 16 / sizeof(T); @@ -203,7 +214,7 @@ template < // Load Q blocks apply scale if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ)); + loader_q.load_safe(short2(BD, params->qL_rem)); } else { loader_q.load_unsafe(); } @@ -221,12 +232,19 @@ template < max_score[i] = Limits::min; } + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + } + // Loop over KV seq length - for (int kb = 0; kb < params->NK; kb++) { + for (int kb = 0; kb < kb_lim; kb++) { // Load K block and apply scale threadgroup_barrier(mem_flags::mem_threadgroup); if (!align_K && kb == (params->NK_aligned)) { - loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + loader_k.load_safe(short2(BD, params->kL_rem)); } else { loader_k.load_unsafe(); } @@ -250,12 +268,11 @@ template < tile_matmad(Stile, Qtile, Ktile, Stile); } - // Mask out of length sequence + // Mask out length sequence if (!align_K && kb == (params->NK_aligned)) { using stile_t = decltype(Stile); using selem_t = typename stile_t::elem_type; constexpr auto neg_inf = -metal::numeric_limits::infinity(); - const short lim = params->kL - params->NK_aligned * BK; STEEL_PRAGMA_UNROLL for (short i = 0; i < stile_t::kTileRows; i++) { @@ -264,7 +281,7 @@ template < short col_pos = sn + (j * stile_t::kFragCols); STEEL_PRAGMA_UNROLL for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { - if ((col_pos + jj) >= lim) { + if ((col_pos + jj) >= params->kL_rem) { Stile.frag_at(i, j)[jj] = neg_inf; } } @@ -272,11 +289,78 @@ template < } } + // Mask out if causal + if (do_causal && kb >= (kb_lim - (BQ + BK - 1) / BK - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 
0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = -metal::numeric_limits::infinity(); + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + } + } + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); // Load V blocks if (!align_K && kb == (params->NK_aligned)) { - loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK)); + loader_v.load_safe(short2(BD, params->kL_rem)); } else { loader_v.load_unsafe(); } @@ -367,8 +451,7 @@ template < O += (tm + sm) * params->O_strides[2] + sn; if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - auto dst_tile_dims = - short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm)); + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) return; diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0d05a6932..fee28fed1 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -1,4 +1,4 @@ -// Copyright © 2024 Apple Inc. +// Copyright © 2024-25 Apple Inc. 
// clang-format off #include "mlx/backend/metal/kernels/utils.h" @@ -6,26 +6,23 @@ #include "mlx/backend/metal/kernels/steel/attn/attn.h" #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h" -#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \ - template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \ - [[kernel]] void attention( \ - const device dtype* Q [[buffer(0)]], \ - const device dtype* K [[buffer(1)]], \ - const device dtype* V [[buffer(2)]], \ - device dtype* O [[buffer(3)]],\ - const constant AttnParams* params [[buffer(4)]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]]); +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention, dtype, bq, bk, bd, wm, wn, mtype, float) -#define instantiate_attn_shapes_helper(iname, itype) \ - instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \ - instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \ - instantiate_attn(iname, itype, 32, 32, 64, 4, 1) +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) -instantiate_attn_shapes_helper(float16, half); -instantiate_attn_shapes_helper(bfloat16, bfloat16_t); +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) -instantiate_attn_shapes_helper(float32, float); +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat16_t); + +instantiate_attn_mask_helper(float32, float); // clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h index 525c50e8f..db5127c33 100644 --- a/mlx/backend/metal/kernels/steel/attn/mma.h +++ b/mlx/backend/metal/kernels/steel/attn/mma.h @@ -111,7 +111,7 @@ struct BaseMMAFrag { for (short j = 0; j < kElemCols; j++) { if ((off_x + i) < lim_x && (off_y + j) < lim_y) { dst[i * kElemCols + j] = - static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + static_cast(src[(off_x + i) * str_x + (off_y + j) * str_y]); } else { dst[i * kElemCols + j] = T(0); } diff --git a/mlx/backend/metal/kernels/steel/attn/params.h b/mlx/backend/metal/kernels/steel/attn/params.h index 4f9680412..f1cf09fad 100644 --- a/mlx/backend/metal/kernels/steel/attn/params.h +++ b/mlx/backend/metal/kernels/steel/attn/params.h @@ -26,11 +26,19 @@ struct AttnParams { int NQ_aligned; ///< Number of full query blocks int NK_aligned; ///< Number of full key/value blocks + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) }; +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + } // namespace steel } // namespace mlx diff --git 
a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 7fbd63022..f7ec004a6 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal( const array& k, const array& v, const float scale, - array& o) { + array& o, + bool do_causal_ = false, + const std::optional& mask = std::nullopt) { using namespace mlx::steel; int wm = 4; @@ -41,11 +43,14 @@ void sdpa_full_self_attention_metal( const bool align_Q = (qL % bq) == 0; const bool align_K = (kL % bk) == 0; + const bool has_mask = !!mask; + const bool do_causal = do_causal_; metal::MTLFCList func_consts = { {&align_Q, MTL::DataType::DataTypeBool, 200}, {&align_K, MTL::DataType::DataTypeBool, 201}, - }; + {&has_mask, MTL::DataType::DataTypeBool, 300}, + {&do_causal, MTL::DataType::DataTypeBool, 301}}; std::ostringstream kname; // clang-format off @@ -54,13 +59,17 @@ void sdpa_full_self_attention_metal( << "_bq" << bq << "_bk" << bk << "_bd" << bd - << "_wm" << wm << "_wn" << wn; // clang-format on + << "_wm" << wm + << "_wn" << wn + << "_mask" << (type_to_name(has_mask ? *mask : q)); // clang-format on std::string base_name = kname.str(); // clang-format off kname << "_align_Q_" << (align_Q ? 't' : 'n') - << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on + << "_align_K_" << (align_K ? 't' : 'n') + << "_has_mask_" << (has_mask ? 't' : 'n') + << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on std::string hash_name = kname.str(); @@ -91,6 +100,10 @@ void sdpa_full_self_attention_metal( /* int NQ_aligned = */ NQ_aligned, /* int NK_aligned = */ NK_aligned, + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ (kL - qL), + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, @@ -102,6 +115,15 @@ void sdpa_full_self_attention_metal( compute_encoder.set_output_array(o, 3); compute_encoder.set_bytes(params, 4); + if (mask) { + auto m = *mask; + AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { + m.strides(0), m.strides(1), m.strides(2)}}; + + compute_encoder.set_bytes(mask_params, 5); + compute_encoder.set_input_array(m, 6); + } + MTL::Size grid_dims = MTL::Size(NQ, H, B); MTL::Size group_dims = MTL::Size(32, wm, wn); @@ -346,7 +368,7 @@ void ScaledDotProductAttention::eval_gpu( // Checks that the headdim dimension has stride 1. auto is_matrix_contiguous = [](const array& arr) { - return arr.strides(3) == 1; + return arr.strides(-1) == 1; }; // We are in vector mode ie single query @@ -415,7 +437,11 @@ void ScaledDotProductAttention::eval_gpu( {str_oB, str_oH, str_oL, str_oD}, flags); - sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); + auto mask = inputs.size() > 3 + ? 
std::optional{copy_unless(is_matrix_contiguous, inputs[3])} + : std::nullopt; + + sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask); } d.add_temporaries(std::move(copies), s.index); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 136c7796a..342078a24 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -567,7 +567,7 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::optional& mask, + const std::variant& mask /* = {}*/, const std::optional memory_efficient_threshold, StreamOrDevice s) { for (const auto& tensor : {queries, keys, values}) { @@ -578,10 +578,29 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } } - if (mask && (*mask).ndim() > 4) { + + bool do_causal = false; + bool has_mask = !std::holds_alternative(mask); + bool has_str_mask = has_mask && std::holds_alternative(mask); + bool has_arr_mask = has_mask && std::holds_alternative(mask); + bool has_bool_mask = false; + + if (has_str_mask) { + if (std::get(mask) != "causal") { + std::ostringstream msg; + msg << "[scaled_dot_product_attention] invalid mask option '" + << std::get(mask) << "'. Must be 'causal', or an array."; + throw std::invalid_argument(msg.str()); + } else { + do_causal = true; + } + } + + if (has_arr_mask && (std::get(mask)).ndim() > 4) { std::ostringstream msg; msg << "[scaled_dot_product_attention] the mask with shape " - << (*mask).shape() << " expected to have at most rank 4"; + << (std::get(mask)).shape() + << " expected to have at most rank 4"; throw std::invalid_argument(msg.str()); } @@ -631,9 +650,11 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - if (mask) { + if (has_arr_mask) { // Check type - if (promote_types(mask->dtype(), final_type) != final_type) { + auto mask_arr = std::get(mask); + has_bool_mask = mask_arr.dtype() == bool_; + if (promote_types(mask_arr.dtype(), final_type) != final_type) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Mask type must promote to output type. " << final_type << "."; @@ -642,9 +663,10 @@ array scaled_dot_product_attention( // Check shape auto mask_shape = queries.shape(); mask_shape.back() = keys.shape(-2); - if (broadcast_shapes(mask->shape(), mask_shape) != mask_shape) { + if (broadcast_shapes(mask_arr.shape(), mask_shape) != mask_shape) { std::ostringstream msg; - msg << "[scaled_dot_product_attention] Mask with shape " << mask->shape() + msg << "[scaled_dot_product_attention] Mask with shape " + << mask_arr.shape() << " does not broadcast to implicit scores with shape " << mask_shape << "."; throw std::invalid_argument(msg.str()); @@ -662,7 +684,7 @@ array scaled_dot_product_attention( threshold = std::max(1, memory_efficient_threshold.value()); } - auto fallback = [scale, final_type, n_q_heads, n_kv_heads, s]( + auto fallback = [scale, final_type, n_q_heads, n_kv_heads, do_causal, s]( const std::vector& inputs) { auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s); int n_repeats = n_q_heads / n_kv_heads; @@ -676,9 +698,21 @@ array scaled_dot_product_attention( v = expand_dims(v, 2, s); } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); - if (inputs.size() > 3) { + if (inputs.size() > 3 || do_causal) { // Mask must be broadcast-compatible with [B, n_q_heads, L_q, L_kv] - auto mask = inputs[3]; + auto mask = inputs.back(); + + if (do_causal) { + int kL = k.shape(-2); + int qL = q.shape(-2); + int q_off = (kL - qL) < 0 ? 
0 : (kL - qL); + auto q_idx = arange(q_off, q_off + qL, s); + auto k_idx = arange(0, kL, s); + q_idx = expand_dims(q_idx, 1, s); + k_idx = expand_dims(k_idx, 0, s); + mask = greater_equal(q_idx, k_idx, s); + } + if (n_repeats > 1 && mask.ndim() >= 3) { if (mask.shape(-3) == 1) { mask = expand_dims(mask, -3, s); @@ -702,9 +736,10 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); - const size_t value_head_dim = v.shape(-1); - const size_t query_head_dim = q.shape(-1); - const size_t query_sequence_length = q.shape(2); + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + const int key_sequence_length = k.shape(2); const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && @@ -712,27 +747,33 @@ array scaled_dot_product_attention( const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); - const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && - sdpa_full_supported_head_dim && stream.device == Device::gpu; + const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask); + const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || + (query_sequence_length <= key_sequence_length && do_causal); + + const bool supports_sdpa_full = query_sequence_length >= threshold && + sdpa_full_supported_mask && sdpa_full_supported_head_dim && + stream.device == Device::gpu; const bool supports_sdpa_vector = (query_sequence_length <= 8) && - (query_sequence_length <= k.shape(-2)) && - (!mask || mask->dtype() == bool_) && sdpa_vector_supported_head_dim && + (query_sequence_length <= key_sequence_length) && + sdpa_vector_supported_mask && sdpa_vector_supported_head_dim && stream.device == Device::gpu; const bool implementation_supports_use_case = supports_sdpa_full || supports_sdpa_vector; std::vector inputs = {q, k, v}; - if (mask) { - inputs.push_back(*mask); + if (has_arr_mask) { + inputs.push_back(std::get(mask)); } if (implementation_supports_use_case) { auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}; return array( std::move(out_shape), final_type, - std::make_shared(stream, fallback, scale), + std::make_shared( + stream, fallback, scale, do_causal), std::move(inputs)); } return fallback(inputs)[0]; @@ -741,7 +782,7 @@ array scaled_dot_product_attention( bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const { const ScaledDotProductAttention& a_other = static_cast(other); - return scale_ == a_other.scale_; + return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_; } array pack_and_quantize( diff --git a/mlx/fast.h b/mlx/fast.h index fe93de85e..b9db6d462 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "mlx/utils.h" @@ -47,7 +48,7 @@ array scaled_dot_product_attention( const array& keys, const array& values, const float scale, - const std::optional& mask = std::nullopt, + const std::variant& mask = {}, const std::optional memory_efficient_threshold = std::nullopt, StreamOrDevice s = {}); diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index ec97fe0ca..4d9e505ee 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -206,8 +206,9 @@ class ScaledDotProductAttention : public Custom { explicit ScaledDotProductAttention( Stream stream, std::function(std::vector)> fallback, - const float scale) - : Custom(stream, fallback), scale_(scale) {} 
+ const float scale, + const bool do_causal) + : Custom(stream, fallback), scale_(scale), do_causal_(do_causal) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -225,12 +226,13 @@ class ScaledDotProductAttention : public Custom { DEFINE_PRINT(ScaledDotProductAttention); DEFINE_INPUT_OUTPUT_SHAPE() auto state() const { - return std::make_pair(nullptr, scale_); + return std::make_tuple(nullptr, scale_, do_causal_); } private: std::function(std::vector)> fallback_; float scale_; + bool do_causal_; }; class AffineQuantize : public Custom { diff --git a/python/src/fast.cpp b/python/src/fast.cpp index fc2cbd41d..95b7dcc9a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -134,7 +134,7 @@ void init_fast(nb::module_& parent_module) { "memory_efficient_threshold"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, str, array] = None, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( A fast implementation of multi-head attention: ``O = softmax(Q @ K.T, dim=-1) @ V``. @@ -164,11 +164,11 @@ void init_fast(nb::module_& parent_module) { k (array): Keys with shape ``[B, N_kv, T_kv, D]``. v (array): Values with shape ``[B, N_kv, T_kv, D]``. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) - mask (array, optional): A boolean or additive mask to apply to the - query-key scores. The mask can have at most 4 dimensions and must - be broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If an - additive mask is given its type must promote to the promoted - type of ``q``, ``k``, and ``v``. + mask (Union[None, str, array], optional): A causal, boolean or additive + mask to apply to the query-key scores. The mask can have at most 4 + dimensions and must be broadcast-compatible with the shape + ``[B, N, T_q, T_kv]``. If an additive mask is given its type must + promote to the promoted type of ``q``, ``k``, and ``v``. Returns: array: The output array. 
)pbdoc"); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 5426ea236..a269847de 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -6,6 +6,91 @@ import mlx_tests import numpy as np +def mlx_ref_attn(q, k, v, scale=1.0, mask=None): + q_dtype = q.dtype + q = q * mx.array(scale, q_dtype) + n_q_heads = q.shape[-3] + n_kv_heads = k.shape[-3] + n_repeats = n_q_heads // n_kv_heads + + B = q.shape[0] + L = q.shape[2] + kL = k.shape[2] + + if n_repeats > 1: + q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) + k = mx.expand_dims(k, 2) + v = mx.expand_dims(v, 2) + + scores = q @ mx.swapaxes(k, -1, -2) + + if mask is not None: + + if mask == "causal": + q_offset = max(0, kL - L) + q_indices = mx.arange(q_offset, q_offset + L) + k_indices = mx.arange(kL) + mask = q_indices[:, None] >= k_indices[None] + + if n_repeats > 1 and mask.ndim >= 3: + if mask.shape[-3] == 1: + mask = mx.expand_dims(mask, -3) + else: + mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) + + if mask.dtype == mx.bool_: + scores = mx.where(mask, scores, -np.float32(np.inf)) + else: + scores += mask + + scores = mx.softmax(scores, axis=-1, precise=True) + + out = scores @ v + if n_repeats > 1: + out = mx.reshape(out, [B, n_q_heads, L, -1]) + + return out + + +def do_attention(f, q, k, v, scale, mask=None, transpose=False): + if transpose: + q_t = mx.transpose(q, (0, 2, 1, 3)) + k_t = mx.transpose(k, (0, 2, 1, 3)) + v_t = mx.transpose(v, (0, 2, 1, 3)) + o_t = f(q_t, k_t, v_t, scale=scale, mask=mask) + return mx.transpose(o_t, (0, 2, 1, 3)) + else: + return f(q, k, v, scale=scale, mask=mask) + + +def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype): + np.random.seed(0) + np_dtype = getattr(np, dtype) + + shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D) + shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D) + + scale = 1.0 / math.sqrt(D) + + q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype) + k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype) + + q_mx = mx.array(q_np) + k_mx = mx.array(k_np) + v_mx = mx.array(v_np) + + if mask is not None: + if mask == "additive": + mask_np = np.random.normal(0.0, 0.5, (B, qH, qL, kL)).astype(np_dtype) + mask = mx.array(mask_np) + elif mask == "bool": + mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5 + mask = mx.array(mask_np) + + return q_mx, k_mx, v_mx, scale, mask + + # SDPA for MHA (n_heads == n_kv_heads) def mlx_primitives_sdpa(q, k, v, scale, mask=None): p = (q * scale) @ k.transpose(0, 1, 3, 2) @@ -365,5 +450,84 @@ class TestFastSDPA(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(ref, out, atol=1e-4, rtol=1e-4)) +class TestSDPA(mlx_tests.MLXTestCase): + @property + def dtypes(self): + return ["float32", "float16"] if mx.metal.is_available() else ["float32"] + + def test_sdpa(self): + if not mx.metal.is_available(): + return + + # fmt: off + shapes_64 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 128, 128, 64, 32, 32), + ( 1, 64, 128, 64, 32, 32), + ( 1, 65, 128, 64, 32, 8), + ( 1, 64, 127, 64, 32, 8), + ( 1, 65, 127, 64, 32, 8), + ( 1, 127, 65, 64, 32, 8), + ) + + shapes_128 = ( + # ( B, qsl, ksl, head_dim, n_qh, n_kvh) + ( 1, 128, 128, 128, 32, 8), + ( 1, 64, 128, 128, 32, 8), + ( 1, 65, 127, 128, 32, 8), + ( 1, 127, 65, 128, 32, 8), + ) + # fmt: on + + shapes = shapes_64 + shapes_128 + masks = [None, "additive", "bool", "causal"] + transposes = (False, True) + + for dtype in 
self.dtypes: + for t in transposes: + for mask_str in masks: + for B, qL, kL, D, qH, kH in shapes: + with self.subTest( + B=B, + qsl=qL, + ksl=kL, + head_dim=D, + n_q_heads=qH, + n_kv_heads=kH, + mask=mask_str, + transpose=t, + dtype=dtype, + ): + + np.random.seed(0) + q_mx, k_mx, v_mx, scale, mask = prepare_inputs( + B, qL, kL, D, qH, kH, mask_str, t, dtype + ) + + out_ref = do_attention( + mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, t + ) + + out_fst = do_attention( + mx.fast.scaled_dot_product_attention, + q_mx, + k_mx, + v_mx, + scale, + mask, + t, + ) + + atol = 2e-5 if dtype == "float32" else 3e-4 + + self.assertListEqual( + list(out_ref.shape), list(out_fst.shape) + ) + + self.assertTrue( + mx.allclose(out_fst, out_ref, atol=atol, rtol=atol) + ) + + if __name__ == "__main__": unittest.main(failfast=True)
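
A minimal usage sketch of the mask API this patch adds to mx.fast.scaled_dot_product_attention. The shapes, dtypes, and random inputs below are illustrative only; the three mask forms (the "causal" string, a boolean array, and an additive array) follow the behavior documented in python/src/fast.cpp and exercised in python/tests/test_fast_sdpa.py above.

import math

import mlx.core as mx

B, n_q_heads, n_kv_heads, L_q, L_kv, D = 1, 32, 8, 64, 128, 64
scale = 1.0 / math.sqrt(D)

q = mx.random.normal((B, n_q_heads, L_q, D))
k = mx.random.normal((B, n_kv_heads, L_kv, D))
v = mx.random.normal((B, n_kv_heads, L_kv, D))

# String mask: fused causal masking, no [L_q, L_kv] mask array is materialized.
o_causal = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Boolean mask: True keeps a score, False drops it to -inf before the softmax.
bool_mask = mx.random.uniform(shape=(B, n_q_heads, L_q, L_kv)) < 0.5
o_bool = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=bool_mask)

# Additive mask: added to the scores; its dtype must promote to the q/k/v dtype.
add_mask = mx.random.normal((B, n_q_heads, L_q, L_kv))
o_add = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=add_mask)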
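
The causal early exit added to steel_attention.h can be summarized with a small sketch: with mask="causal" the kernel stops its K/V loop at kb_lim instead of params->NK. The helper below is illustrative only; the arithmetic mirrors the kb_lim computation in the kernel.

def causal_kv_block_limit(q_tile_idx, BQ, BK, qL_off):
    # The last query row of tile q_tile_idx sits at absolute key position
    # (q_tile_idx + 1) * BQ - 1 + qL_off, where qL_off = kL - qL aligns the
    # causal diagonal with the end of the key sequence. Only the first
    # ceil(q_max / BK) key/value blocks can contain unmasked scores, so the
    # kernel runs its kb loop for kb in [0, kb_lim).
    q_max = (q_tile_idx + 1) * BQ + qL_off
    return (q_max + BK - 1) // BK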