From 1865299a30d2c47fb4497e3dd01800965b21b088 Mon Sep 17 00:00:00 2001 From: Brian Keene Date: Mon, 3 Jun 2024 12:16:19 -0400 Subject: [PATCH] Metal shaders for memory efficient self attention on large sequences (#964) * Metal shaders for efficient self attention on large sequences Updated fast attention: GEMM-ified with Steel primitives Uses flash attention 1 for scale correction * more compiler silencing * Address rebase issues * Templatize kernel instantiation, revise cpu bindings * Safer writes to output * Permit batch size > 1 * Numerical fixes for sdpa self attention * Re-enable test, remove unused variable * add benchmarking script * Disable sdpa prior to perf tuning, and simplify tests for per-patch CI --- benchmarks/python/sdpa_bench.py | 64 ++ .../scaled_dot_product_attention.metal | 917 ++++++++++++++++++ .../scaled_dot_product_attention_params.h | 28 + .../metal/scaled_dot_product_attention.cpp | 140 +++ mlx/fast.cpp | 30 +- python/src/fast.cpp | 1 - python/tests/test_fast_sdpa.py | 73 +- 7 files changed, 1244 insertions(+), 9 deletions(-) create mode 100644 benchmarks/python/sdpa_bench.py diff --git a/benchmarks/python/sdpa_bench.py b/benchmarks/python/sdpa_bench.py new file mode 100644 index 000000000..fcaad7b6a --- /dev/null +++ b/benchmarks/python/sdpa_bench.py @@ -0,0 +1,64 @@ +import argparse +import math + +import mlx.core as mx +from time_utils import time_fn + +MAX_SEQ = 300 +START_SEQ = 100 +SEQ_INCREMENT = 50 + + +def time_self_attention_primitives(): + + mx.random.seed(3) + B = 2 + H = 38 + D = 64 + for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT): + q = mx.random.uniform(shape=(B, H, R, D)) + k = mx.random.uniform(shape=(B, H, R, D)) + v = mx.random.uniform(shape=(B, H, R, D)) + scale = 1.0 / math.sqrt(float(D)) + mx.eval(q, k, v) + + def sdpa_primitives(qs, ks, vs, alpha): + s = (alpha * qs) @ ks.transpose(0, 1, 3, 2) + p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) + o = p @ vs + return o + + time_fn(sdpa_primitives, q, k, v, scale) + + +def time_self_attention_sdpa(): + + mx.random.seed(3) + B = 2 + H = 38 + D = 64 + for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT): + q = mx.random.uniform(shape=(B, H, R, D)) + k = mx.random.uniform(shape=(B, H, R, D)) + v = mx.random.uniform(shape=(B, H, R, D)) + scale = 1.0 / math.sqrt(float(D)) + mx.eval(q, k, v) + + def sdpa_fused(qs, ks, vs, alpha): + o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha) + return o + + time_fn(sdpa_fused, q, k, v, scale) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLX benchmarks.") + parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + args = parser.parse_args() + if args.gpu: + mx.set_default_device(mx.gpu) + else: + mx.set_default_device(mx.cpu) + + time_self_attention_sdpa() + time_self_attention_primitives() diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index 1d8d9f9d6..a7cccc710 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -1,9 +1,926 @@ #include #include +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + #include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" using namespace metal; +using namespace mlx::steel; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short 
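+    // Defaulted tiling parameters: each of the tgp_size threads copies
+    // n_reads contiguous elements per row, so the threadgroup forms a
+    // TROWS x TCOLS grid of cooperative loaders; `alignment` sets the
+    // alignment of the vectorized ReadVector used by load_unsafe().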
alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderFA { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoaderFA( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } + METAL_FUNC void next(short n) { + src += n * tile_stride; + } +}; + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMAFA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? 
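+      // offset between the simdgroup's successive 8x8 A tiles along M in
+      // threadgroup memory: TM_stride columns when A is stored transposed,
+      // otherwise TM_stride rows of the padded leading dimension lda_tgp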
TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + ushort sid; + ushort slid; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMAFA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + slid = simd_lane_id; + sid = simd_group_id; + + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
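+          // serpentine ordering: odd warp-tile rows traverse N in reverse so
+          // the B fragment used last in one row is reused first in the next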
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + METAL_FUNC void rescale_output(const threadgroup float* Corrections) { + // Loop over all simdgroup tiles + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + short row = sm + tm + i * TM_stride; + float scale_value = Corrections[row]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + // int offset = (i * TM_stride) * ldc + (j * TN_stride); + accum[0] *= scale_value; + accum[1] *= scale_value; + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_to_tgp_memory( + threadgroup U* C, + const int ldc, + short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } + + METAL_FUNC void clear_results() { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + results[i * TN + j] = simdgroup_matrix(0); + } + } + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct FastAttentionKernel { + STEEL_CONST short tgp_padding = 16 / sizeof(T); + STEEL_CONST short float_padding = 16 / sizeof(float); + STEEL_CONST short tgp_mem_size_q = + transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_k = + transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_v = + transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); + + // maxes, rowsums, rescale + STEEL_CONST short tgp_mem_size_corrections = + 4 * (BM * sizeof(float) + float_padding); + + STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; + + STEEL_CONST short tgp_mem_size = share_kv_smem + ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + tgp_mem_size_v; + + STEEL_CONST short tgp_size = WM * WN * 32; + + static_assert(transpose_q == false, "Expected Q not transposed."); + static_assert(transpose_k == true, "Expected K transposed."); + static_assert(transpose_v == false, "Expected V not transposed."); + static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); + + using loader_q_t = BlockLoaderFA< + T, + transpose_q ? BK : BM, + transpose_q ? BM : BK, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + !transpose_q, + tgp_size>; + + using loader_k_t = BlockLoaderFA< + T, + transpose_k ? 
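+      // K is consumed transposed for Q @ K^T: each K tile spans BN keys by
+      // BK head-dim elements, padded to a leading dimension of BK + tgp_padding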
BN : BK, + transpose_k ? BK : BN, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + transpose_k, + tgp_size>; + + using loader_v_t = BlockLoaderFA< + T, + transpose_v ? BK : BN, + transpose_v ? BN : BK, + transpose_v ? BN + tgp_padding : BK + tgp_padding, + transpose_v, + tgp_size>; + + using mma_qk_t = BlockMMAFA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + AccumType, + Epilogue>; + + using mma_sv_t = BlockMMAFA< + T, + U, + BM, + BK, + BN, + WM, + WN, + false, + transpose_v, + BN + tgp_padding, + BK + tgp_padding, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_k_t& loader_b, + thread mma_qk_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + (void)tgp_bm; + + short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + // not valid for gemm_k_iterations > 1 (so, BK == d_k) + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + } + + static METAL_FUNC void initialize_corrections( + threadgroup float* C, + uint simd_lane_id, + uint simd_group_id) { + if (simd_group_id == 0) { + threadgroup float* maxes = C; + threadgroup float* sums = C + (BM + float_padding); + threadgroup float* o_rescale = sums + (BM + float_padding); + threadgroup float* output_rescale = o_rescale + (BM + float_padding); + + if (simd_lane_id < BM) { + maxes[simd_lane_id] = -INFINITY; // m_i + sums[simd_lane_id] = 0.f; // l_i + o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) + output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + } + } + } + + static METAL_FUNC void rescale_ss( + threadgroup T* Ss, + threadgroup float* Corrections, + uint simd_group_id, + uint simd_lane_id, + short2 local_blocks, + float alpha) { + if (simd_group_id == 0) { + short row_offset = BM + float_padding; + threadgroup float* maxes = Corrections; + threadgroup float* sums = Corrections + row_offset; + threadgroup float* o_rescale = sums + row_offset; + threadgroup float* output_scales = o_rescale + row_offset; + + if (simd_lane_id < uint(local_blocks.y)) { + float m_i_old = maxes[simd_lane_id]; + float l_i_old = sums[simd_lane_id]; + + float m_i_new = m_i_old; + float l_i_new = l_i_old; + + short offset = simd_lane_id * (BN + tgp_padding); + + float m_ij = -INFINITY; + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + m_ij = max(m_ij, val); + } + + m_i_new = max(m_ij, m_i_new); + + float rowsum = 0.f; // lij + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + float P_i_j = exp(val - m_ij); + rowsum += P_i_j; + P_i_j = P_i_j * exp(m_ij - m_i_new); + Ss[offset + j] = T(P_i_j); + } + + l_i_new = + exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; + maxes[simd_lane_id] = m_i_new; + sums[simd_lane_id] = l_i_new; + float rescale = l_i_old * exp(m_i_old - m_i_new); + o_rescale[simd_lane_id] = rescale; + output_scales[simd_lane_id] = 1.0 / l_i_new; + } + } + } 
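+
+  /* Flash attention (v1) style correction computed by rescale_ss above and
+   * applied per KV block in run() below:
+   *   m_i_new = max(m_i_old, max_j(alpha * s_ij))
+   *   l_i_new = exp(m_i_old - m_i_new) * l_i_old
+   *           + exp(m_ij - m_i_new) * sum_j exp(alpha * s_ij - m_ij)
+   * Ss is overwritten with the shifted probabilities, o_rescale holds
+   * l_i_old * exp(m_i_old - m_i_new) used to rescale the partial output
+   * before accumulating the new P @ V block, and output_scales holds
+   * 1 / l_i_new for the final normalization. */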
+ + /* Main kernel function */ + static METAL_FUNC void run( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device U* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + threadgroup T* Qs [[threadgroup(0)]], + threadgroup T* Ks [[threadgroup(1)]], + threadgroup T* Ss [[threadgroup(2)]], + threadgroup T* Vs [[threadgroup(3)]], + threadgroup float* Corrections [[threadgroup(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in Q, O; and head in K, V. + const int c_row = tid_y * BM; + + Q += transpose_q ? c_row : c_row * params->ldq; + thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); + + short tgp_bm = min(BM, params->M - c_row); + short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + loader_q.load_safe(tile_dims_Q); + + initialize_corrections(Corrections, simd_lane_id, simd_group_id); + + O += c_row * params->ldo; + + // Prepare threadgroup mma operation + thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); + thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); + thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); + thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); + + for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; + n_block++) { + short c_col = BN; + + // Prepare threadgroup loading operations + short gemm_k_iterations = params->gemm_k_iterations_aligned; + short tgp_bn_qk = min(BN, params->N - c_col * n_block); + threadgroup_barrier(mem_flags::mem_none); + + /////////////////////////////////////////////////////////////////////////////// + { // Loop over K - unaligned case + + if (tgp_bm == BM && tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } else if (tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else if (tgp_bm == BM) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } + } + + mma_qk_op.store_result_to_tgp_memory( + Ss, BN + tgp_padding, short2(BN, BM)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + rescale_ss( + Ss, + Corrections, + simd_group_id, + simd_lane_id, + short2(tgp_bn_qk, tgp_bm), + params->alpha); + + loader_v.load_safe(short2(BK, tgp_bn_qk)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); + mma_softmax_sv_op.rescale_output(o_scales); + + mma_softmax_sv_op.mma(Ss, Vs); + + threadgroup float* final_output_scales = + Corrections + 3 * (BM + float_padding); + + mma_softmax_sv_op.rescale_output(final_output_scales); + + loader_v.next(); + loader_k.next(BN); + + mma_qk_op.clear_results(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + 
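+    // All KV blocks processed: the accumulators hold the normalized attention
+    // output for this BM-row block; write the BM x BK tile to O with bounds
+    // checks on the trailing rows (tgp_bm < BM only for the last row block).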
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + } +}; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using attention_kernel = FastAttentionKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_v, + MN_aligned, + K_aligned>; + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* Q_bstrides = batch_strides; + const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); + + Q += batch_offsets.x; + K += batch_offsets.y; + V += batch_offsets.y; + + } else { + Q += params->batch_stride_q * tid.z; + K += params->batch_stride_k * tid.z; + V += params->batch_stride_v * tid.z; + } + + // same shape as input + O += params->batch_stride_o * tid.z; + threadgroup T Qs[attention_kernel::tgp_mem_size_q]; + threadgroup T Ss[attention_kernel::tgp_mem_size_s]; + threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; + + if (attention_kernel::share_kv_smem) { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } else { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T Vs[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } +} + +#define instantiate_fast_inference_self_attention_kernel( \ + itype, otype, bm, bn, bk, wm, wn) \ + template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ + "_itype_" #itype)]] [[kernel]] void \ + attention( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + device otype* O [[buffer(3)]], \ + const constant MLXFastAttentionParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 64, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 128, + 2, + 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); + template < typename T, typename T2, diff --git 
a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h index 09b9defb8..a77dad268 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h @@ -4,6 +4,34 @@ #pragma once +struct MLXFastAttentionParams { + const int M; + const int N; + const int K; + + const int ldq; // ldq == ldo + const int ldk; + const int ldv; + const int lds; + const int ldo; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_q; + const int batch_stride_k; + const int batch_stride_v; + const int batch_stride_o; + + const int swizzle_log; + const int gemm_n_iterations_aligned; + const int gemm_k_iterations_aligned; + const int gemm_sv_m_block_iterations; + + const int batch_ndim; + const float alpha; +}; + struct MLXScaledDotProductAttentionParams { // Associated dimensions & transposition information const uint QUERY_SEQUENCE_LENGTH = 1; diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 0ded93397..92d3ee05a 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -19,6 +19,140 @@ namespace mlx::core::fast { namespace { +void sdpa_full_self_attention_metal( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const float alpha, + array& out, + std::vector& temporaries) { + std::ostringstream kname_self_attention; + kname_self_attention << "steel_gemm_attention_"; + + constexpr const int bm = 16; + constexpr const int bn = 16; + const int bk = q.shape(-1); // already forced to be 64 or 128 + + if (bk != 64 && bk != 128) { + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu]: hidden dim: expected either 64, 128"); + } + + constexpr const int wm = 2; + constexpr const int wn = 2; + + std::string delimiter = "_"; + + kname_self_attention << "bm_" + std::to_string(bm) + delimiter; + kname_self_attention << "bn_" + std::to_string(bn) + delimiter; + kname_self_attention << "bk_" + std::to_string(bk) + delimiter; + + for (const auto& arr : {k, v, out}) { + if (arr.dtype() != q.dtype()) { + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o"); + } + } + + if (q.dtype() == float32) { + kname_self_attention << "itype" + delimiter + "float"; + } else if (q.dtype() == float16) { + kname_self_attention << "itype" + delimiter + "half"; + } else { + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16."); + } + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname_self_attention.str()); + compute_encoder->setComputePipelineState(kernel); + + uint hidden_dim = q.shape(-1); + uint qseq = q.shape(-2); + uint qheads = q.shape(-3); + + const uint64_t KV_sequence_length = k.shape(-2); + const uint query_sequence_length = q.shape(-2); + const uint n_q_heads = q.shape(1); + const uint n_kv_heads = k.shape(1); + + const int M = q.shape(-2); + const int N = M; + const int K = q.shape(-1); + const size_t batch_size_out = q.shape(0) * q.shape(1); + + const std::vector batch_shape = {q.shape(0) * q.shape(1)}; + const int dk = q.shape(-1); + const int ldq = dk; + const int ldk = dk; + const int ldv = dk; + const int lds = bn; + const int ldo = dk; + + int tn = 1; + int tm = (M + bm - 1) / bm; + + 
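+  // Self-attention only: Q, K, V, and O all share the shape (B, H, L, dk),
+  // so a single stride of dk * L advances one (batch, head) slice for every
+  // array, and batch_shape collapses B * H into a single batch dimension.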
const int batch_stride_q = dk * query_sequence_length; + const int batch_stride_k = dk * query_sequence_length; + const int batch_stride_v = dk * query_sequence_length; + const int batch_stride_o = dk * query_sequence_length; + const int swizzle_log = 0; + const int gemm_n_iterations_aligned = (N + bn - 1) / bn; + const int gemm_k_iterations_aligned = (K + bk - 1) / bk; + const int gemm_sv_m_block_iterations = (M + bm - 1) / bm; + const int batch_ndim = int(batch_shape.size()); + + MLXFastAttentionParams params{ + (int)M, + (int)N, + (int)K, + ldq, + ldk, + ldv, + lds, + ldo, + tn, + tm, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o, + swizzle_log, + gemm_n_iterations_aligned, + gemm_k_iterations_aligned, + gemm_sv_m_block_iterations, + batch_ndim, + alpha}; + + const std::vector batch_strides = { + (size_t)batch_stride_q, + (size_t)batch_stride_k, + (size_t)batch_stride_v, + (size_t)batch_stride_o}; + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_output_array(out, 3); + + compute_encoder->setBytes(¶ms, sizeof(MLXFastAttentionParams), 4); + compute_encoder->setBytes( + batch_shape.data(), sizeof(int) * batch_shape.size(), 6); + + compute_encoder->setBytes( + batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7); + + MTL::Size grid_dims = MTL::Size(1, tm, batch_size_out); + MTL::Size group_dims = MTL::Size(32, wm, wn); + + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); }); + return; +} void sdpa_metal( const Stream& s, @@ -170,6 +304,12 @@ void ScaledDotProductAttention::eval_gpu( auto v = check_transpose(v_pre); const int heads = q.shape(-3); + + uint query_sequence_length = q.shape(-2); + if (query_sequence_length >= 16) { + return sdpa_full_self_attention_metal( + s, d, q, k, v, scale_, out, temporaries); + } int tile_size = 64; const int kv_seq_len = k.shape(-2); if (kv_seq_len > 8000) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 52e7c8c21..0a0bd2066 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -527,11 +527,13 @@ array scaled_dot_product_attention( /* generic implementation for use cases that Metal implementation does not * support. 
For non-supported cases listed below, use MLX primitives: * * CPU implementation - * * batch size > 1 - * * query sequence length > 1 + * * batch size > 1 for decoding or causal attention + * * query sequence length > 1 for decoding + * * query sequence length > 16 && non-null mask (causal attention) * * non-null mask * * dtype is not fp32 or fp16 */ + bool needs_mask = mask.has_value(); auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s]( const std::vector& inputs) { @@ -559,15 +561,29 @@ array scaled_dot_product_attention( }; auto stream = to_stream(s); - constexpr const int supported_head_dim = 128; const size_t query_head_dim = q.shape(-1); + const bool supported_head_dim = + query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128; + + const bool supported_head_dim_self_attn = + query_head_dim == 64 || query_head_dim == 128; const size_t query_sequence_length = q.shape(2); - bool implementation_supports_use_case = batch_dim == 1 && - query_sequence_length == 1 && !mask.has_value() && - query_head_dim == supported_head_dim && final_type != bfloat16 && + const bool supports_full_self_attention = query_sequence_length >= 16 && + !mask.has_value() && supported_head_dim_self_attn && + n_q_heads == n_kv_heads && final_type != bfloat16 && stream.device == Device::gpu; - // TODO, update routing conditions post further tuning + + // fast decoding gpu shader + bool supports_sdpa = batch_dim == 1 && query_sequence_length == 1 && + !mask.has_value() && supported_head_dim && final_type != bfloat16 && + stream.device == Device::gpu; + bool implementation_supports_use_case = + supports_sdpa || supports_full_self_attention; + + // disabling full self attention until perf is tuned; + // likewise for sdpa implementation_supports_use_case &= false; + if (implementation_supports_use_case) { auto out_shape = std::vector({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}); diff --git a/python/src/fast.cpp b/python/src/fast.cpp index d74eca6e8..f729b76fc 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -135,7 +135,6 @@ void init_fast(nb::module_& parent_module) { v (array): Input values array. scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``) mask (array, optional): An additive mask to apply to the query-key scores. - Returns: array: The output array. 
)pbdoc"); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index b4fa07395..51b2c047c 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -32,9 +32,80 @@ def mlx_primitives_sdpa_with_gqa(q, k, v, scale): return mlx_primitives_sdpa(q, k, v, scale) -class TestFastSDPA(mlx_tests.MLXTestCase): +class TestFastSelfAttentionSDPA(mlx_tests.MLXTestCase): def test_fast_sdpa(self): + # Not yet supported: + # * K pre-transposed in kernel, V pre-transposed in kernel + np.random.seed(0) + R = 20 + L = R + Dk = 64 + H = 3 + scale = float(1.0 / np.sqrt(Dk)) + q_npy = np.random.normal(0.0, 1.0, (1, H, R, Dk)).astype(np.float32) + k_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32) + v_npy = np.random.normal(0.0, 1.0, (1, H, L, Dk)).astype(np.float32) + + q_mlx = mx.array(q_npy) + k_mlx = mx.array(k_npy) + v_mlx = mx.array(v_npy) + + reference = mlx_primitives_sdpa(q_mlx, k_mlx, v_mlx, scale) + + o_mlx = mx.fast.scaled_dot_product_attention( + q_mlx, k_mlx, v_mlx, scale=scale, mask=None + ) + + self.assertListEqual(list(reference.shape), list(o_mlx.shape)) + self.assertTrue(mx.allclose(o_mlx, reference, atol=1e-4)) + + dtypes = [np.float32] + + Dk = 64 + + if self.is_apple_silicon: + dtypes.append(np.half) + + for SEQUENCE_LENGTH in [63, 129, 400]: + for DTYPE in dtypes: + B = 2 + H = 24 + n_kv_heads = H + q_npy = np.random.normal(0.0, 1.0, (B, H, SEQUENCE_LENGTH, Dk)).astype( + DTYPE + ) + k_npy = np.random.normal( + 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) + ).astype(DTYPE) + v_npy = np.random.normal( + 0.0, 1.0, (B, n_kv_heads, SEQUENCE_LENGTH, Dk) + ).astype(DTYPE) + + q_mlx = mx.array(q_npy) + k_mlx = mx.array(k_npy) + v_mlx = mx.array(v_npy) + + reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale) + o_mlx = mx.fast.scaled_dot_product_attention( + q_mlx, k_mlx, v_mlx, scale=scale + ) + + self.assertListEqual(list(reference.shape), list(o_mlx.shape)) + rtol = 1e-3 + atol = 1e-2 + + if SEQUENCE_LENGTH > 500: + rtol = 1e-2 + + if DTYPE == np.half: + rtol = 1e-2 + + self.assertTrue(mx.allclose(o_mlx, reference, rtol=rtol, atol=atol)) + + +class TestFastSDPA(mlx_tests.MLXTestCase): + def test_fast_sdpa(self): # Not yet supported: # * K pre-transposed in kernel, V pre-transposed in kernel np.random.seed(0)