From 990b1acc75efba56fa56ee9b27d3b2f424514414 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Thu, 21 Nov 2024 10:33:45 -0800
Subject: [PATCH] Remove older fast attention code. Write out O strided
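
Remove the older steel_gemm_attention kernels together with the
MLXFastAttentionParams / MLXScaledDotProductAttentionParams structs that
only they consumed; full attention now runs through the steel/attn path.
Also allocate the SDPA output with explicit strides so that O, with
logical shape (B, H, L, D), is written out in (B, L, H, D) memory order.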
---
 mlx/backend/metal/kernels/CMakeLists.txt           |   4 +-
 .../scaled_dot_product_attention.metal             | 919 ------------------
 .../scaled_dot_product_attention_params.h          |  42 -
 .../metal/scaled_dot_product_attention.cpp         |  21 +-
 4 files changed, 19 insertions(+), 967 deletions(-)
 delete mode 100644 mlx/backend/metal/kernels/scaled_dot_product_attention_params.h

diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 936127e20c..e6416581ab 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -44,9 +44,7 @@ build_kernel(layer_norm)
 build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
-build_kernel(
-  scaled_dot_product_attention scaled_dot_product_attention_params.h
-  sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
+build_kernel(scaled_dot_product_attention sdpa_vector.h)
 
 set(STEEL_HEADERS
     steel/defines.h
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index 0142312850..b5bc9607e7 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -1,930 +1,11 @@
 #include <metal_stdlib>
 
-#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
 #include "mlx/backend/metal/kernels/sdpa_vector.h"
-#include "mlx/backend/metal/kernels/steel/defines.h"
-#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 using namespace metal;
-using namespace mlx::steel;
-
-template <
-    typename T,
-    short BROWS,
-    short BCOLS,
-    short dst_ld,
-    short reduction_dim,
-    short tgp_size,
-    short alignment = 1,
-    short n_reads = (BCOLS * BROWS) / (tgp_size),
-    short TCOLS = BCOLS / n_reads,
-    short TROWS = tgp_size / TCOLS>
-struct BlockLoaderFA {
-  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
-  STEEL_CONST short vec_size = n_reads;
-
-  // Leading dimension for src
-  const int src_ld;
-  const int tile_stride;
-
-  // Thread location indices
-  const short thread_idx;
-  const short bi;
-  const short bj;
-
-  // threadgroup and device memory
-  threadgroup T* dst;
-  const device T* src;
-
-  struct alignas(alignment * sizeof(T)) ReadVector {
-    uint8_t v[sizeof(T) * vec_size];
-  };
-
-  /* Constructor */
-  METAL_FUNC BlockLoaderFA(
-      const device T* src_,
-      const int src_ld_,
-      threadgroup T* dst_,
-      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
-      ushort simd_lane_id [[thread_index_in_simdgroup]])
-      : src_ld(src_ld_),
-        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
-        thread_idx(simd_group_id * 32 + simd_lane_id),
-        bi(thread_idx / TCOLS),
-        bj(vec_size * (thread_idx % TCOLS)),
-        dst(dst_ + bi * dst_ld + bj),
-        src(src_ + bi * src_ld + bj) {}
-
-  /* Load from device memory into threadgroup memory - without bound checking */
-  METAL_FUNC void load_unsafe() const {
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < BROWS; i += TROWS) {
-      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
-          *((const device ReadVector*)(&src[i * src_ld]));
-    }
-  }
-
-  /* Load from device memory into threadgroup memory - with bound checking */
-  METAL_FUNC void load_safe(short2 src_tile_dim) const {
-    src_tile_dim = src_tile_dim - short2(bj, bi);
-
-    // Skip loading if thread has no valid reads
-    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < BROWS; i += TROWS) {
-        STEEL_PRAGMA_UNROLL
-        for (short j = 0; j < vec_size; j++) {
-          dst[i * dst_ld + j] = T(0);
-        }
-      }
-      return;
-    }
-
-    // Use fast thread memory for bound checks
-    bool tmp_idx[vec_size];
-    T tmp_val[vec_size];
-
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < BROWS; i += TROWS) {
-      // Make sure tmp_idx only contains valid indices
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
-      }
-
-      // Read valid indices into tmp_val
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
-      }
-
-      // Zero out uneeded values
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
-      }
-
-      // Copy values to threadgroup memory
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < vec_size; j++) {
-        dst[i * dst_ld + j] = tmp_val[j];
-      }
-    }
-  }
-
-  /* Iteration helper */
-  METAL_FUNC void next() {
-    src += tile_stride;
-  }
-  METAL_FUNC void next(short n) {
-    src += n * tile_stride;
-  }
-};
-
-template <bool M_aligned, bool N_aligned, bool K_aligned>
-struct LoopAlignment {};
-
-template <
-    typename T,
-    typename U,
-    int BM,
-    int BN,
-    int BK,
-    int WM,
-    int WN,
-    bool transpose_a,
-    bool transpose_b,
-    short lda_tgp,
-    short ldb_tgp,
-    typename AccumType = float,
-    typename Epilogue = TransformNone<U, AccumType>>
-struct BlockMMAFA {
-  // Warp tile simdgroup matrix strides along M
-  STEEL_CONST short TM_stride = 8 * WM;
-  // Warp tile simdgroup matrix strides along M
-  STEEL_CONST short TN_stride = 8 * WN;
-
-  // Warp tile size along M
-  STEEL_CONST short TM = BM / TM_stride;
-  // Warp tile size along N
-  STEEL_CONST short TN = BN / TN_stride;
-
-  // Strides of A, B along reduction axis
-  STEEL_CONST short simd_stride_a = {
-      transpose_a ? TM_stride : TM_stride * lda_tgp};
-  STEEL_CONST short simd_stride_b = {
-      transpose_b ? TN_stride * ldb_tgp : TN_stride};
-
-  // Jump between elements
-  STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
-  STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
-
-  STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
-  STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
-
-  // Simdgroup matrices
-  simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
-  simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
-  simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
-      simdgroup_matrix<AccumType, 8, 8>(0)};
-
-  // Offsets within threadgroup
-  const short tm;
-  const short tn;
-
-  short sm;
-  short sn;
-
-  ushort sid;
-  ushort slid;
-
-  short As_offset;
-  short Bs_offset;
-
-  /* Constructor */
-  METAL_FUNC BlockMMAFA(
-      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
-      ushort simd_lane_id [[thread_index_in_simdgroup]])
-      : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
-    // Determine thread position in simdgroup matrix
-    short qid = simd_lane_id / 4;
-    slid = simd_lane_id;
-    sid = simd_group_id;
-
-    sm = (qid & 4) + (simd_lane_id / 2) % 4;
-    sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
-
-    // Determine thread and simdgroup offset
-    As_offset =
-        transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
-    Bs_offset =
-        transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
-  }
-
-  /* (BM, BK) X (BK, BN) multiply accumulate function */
-  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
-    // Adjust for simdgroup and thread location
-    As += As_offset;
-    Bs += Bs_offset;
-
-    // Iterate over BK in blocks of 8
-    STEEL_PRAGMA_UNROLL
-    for (short kk = 0; kk < BK; kk += 8) {
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Load elements from threadgroup A as simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < TM; i++) {
-        Asimd[i].thread_elements()[0] =
-            static_cast<AccumType>(As[i * simd_stride_a + 0]);
-        Asimd[i].thread_elements()[1] =
-            static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Load elements from threadgroup B as simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        Bsimd[j].thread_elements()[0] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
-        Bsimd[j].thread_elements()[1] =
-            static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
-      }
-
-      simdgroup_barrier(mem_flags::mem_none);
-
-      // Multiply and accumulate into result simdgroup matrices
-      STEEL_PRAGMA_UNROLL
-      for (short i = 0; i < TM; i++) {
-        STEEL_PRAGMA_UNROLL
-        for (short j = 0; j < TN; j++) {
-          short j_serp = (i % 2) ? (TN - 1 - j) : j;
-
-          simdgroup_multiply_accumulate(
-              results[i * TN + j_serp],
-              Asimd[i],
-              Bsimd[j_serp],
-              results[i * TN + j_serp]);
-        }
-      }
-
-      // Progress to next simdgroup tile
-      As += tile_stride_a;
-      Bs += tile_stride_b;
-    }
-  }
-
-  METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
-    // Loop over all simdgroup tiles
-
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < TM; i++) {
-      short row = sm + tm + i * TM_stride;
-      float scale_value = Corrections[row];
-
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        // Get accumulated result and associated offset in C
-        thread auto& accum = results[i * TN + j].thread_elements();
-        // int offset = (i * TM_stride) * ldc + (j * TN_stride);
-        accum[0] *= scale_value;
-        accum[1] *= scale_value;
-      }
-    }
-  }
-
-  /* Store results from simdgroup_matrix results into device memory */
-  METAL_FUNC void store_result(device U* C, const int ldc) const {
-    // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + tn + sn;
-
-    // Loop over all simdgroup tiles
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < TM; i++) {
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        // Get accumulated result and associated offset in C
-        thread const auto& accum = results[i * TN + j].thread_elements();
-        int offset = (i * TM_stride) * ldc + (j * TN_stride);
-
-        // Apply epilogue
-        U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
-
-        // Write out C
-        C[offset] = outs[0];
-        C[offset + 1] = outs[1];
-      }
-    }
-  }
-
-  METAL_FUNC void store_result_to_tgp_memory(
-      threadgroup U* C,
-      const int ldc,
-      short2 dst_tile_dims) const {
-    // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + (tn + sn);
-    dst_tile_dims -= short2(tn + sn, sm + tm);
-
-    STEEL_PRAGMA_UNROLL
-    for (int i = 0; i < TM; i++) {
-      if (i * TM_stride < dst_tile_dims.y) {
-        STEEL_PRAGMA_UNROLL
-        for (int j = 0; j < TN; j++) {
-          // Get accumulated result and associated offset in C
-          thread const auto& accum = results[i * TN + j].thread_elements();
-          int offset = (i * TM_stride) * ldc + (j * TN_stride);
-
-          // Apply epilogue and output C
-          if (j * TN_stride < dst_tile_dims.x) {
-            C[offset] = Epilogue::apply(accum[0]);
-          }
-
-          if (j * TN_stride + 1 < dst_tile_dims.x) {
-            C[offset + 1] = Epilogue::apply(accum[1]);
-          }
-        }
-      }
-    }
-  }
-
-  METAL_FUNC void
-  store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
-    // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + (tn + sn);
-    dst_tile_dims -= short2(tn + sn, sm + tm);
-
-    STEEL_PRAGMA_UNROLL
-    for (int i = 0; i < TM; i++) {
-      if (i * TM_stride < dst_tile_dims.y) {
-        STEEL_PRAGMA_UNROLL
-        for (int j = 0; j < TN; j++) {
-          // Get accumulated result and associated offset in C
-          thread const auto& accum = results[i * TN + j].thread_elements();
-          int offset = (i * TM_stride) * ldc + (j * TN_stride);
-
-          // Apply epilogue and output C
-          if (j * TN_stride < dst_tile_dims.x) {
-            C[offset] = Epilogue::apply(accum[0]);
-          }
-
-          if (j * TN_stride + 1 < dst_tile_dims.x) {
-            C[offset + 1] = Epilogue::apply(accum[1]);
-          }
-        }
-      }
-    }
-  }
-
-  /* Store results from simdgroup_matrix results into device memory */
-  METAL_FUNC void store_result(
-      device U* D,
-      const int ldd,
-      const device U* C,
-      const int ldc,
-      const int fdc,
-      thread const Epilogue& epilogue_op) const {
-    // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + (tn + sn) * fdc;
-    D += (sm + tm) * ldd + tn + sn;
-
-    // Loop over all simdgroup tiles
-    STEEL_PRAGMA_UNROLL
-    for (short i = 0; i < TM; i++) {
-      STEEL_PRAGMA_UNROLL
-      for (short j = 0; j < TN; j++) {
-        // Get accumulated result and associated offset in C
-        thread const auto& accum = results[i * TN + j].thread_elements();
-        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
-        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
-
-        // Apply epilogue
-        U outs[2] = {
-            epilogue_op.apply(accum[0], C[offset_c]),
-            epilogue_op.apply(accum[1], C[offset_c + fdc])};
-
-        // Write out D
-        D[offset_d] = outs[0];
-        D[offset_d + 1] = outs[1];
-      }
-    }
-  }
-
-  METAL_FUNC void store_result_safe(
-      device U* D,
-      const int ldd,
-      const device U* C,
-      const int ldc,
-      const int fdc,
-      short2 dst_tile_dims,
-      thread const Epilogue& epilogue_op) const {
-    // Adjust for simdgroup and thread location
-    C += (sm + tm) * ldc + (tn + sn) * fdc;
-    D += (sm + tm) * ldd + tn + sn;
-    dst_tile_dims -= short2(tn + sn, sm + tm);
-
-    STEEL_PRAGMA_UNROLL
-    for (int i = 0; i < TM; i++) {
-      if (i * TM_stride < dst_tile_dims.y) {
-        STEEL_PRAGMA_UNROLL
-        for (int j = 0; j < TN; j++) {
-          // Get accumulated result and associated offset in C
-          thread const auto& accum = results[i * TN + j].thread_elements();
-          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
-          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
-
-          // Apply epilogue and output C
-          if (j * TN_stride < dst_tile_dims.x) {
-            D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
-          }
-
-          if (j * TN_stride + 1 < dst_tile_dims.x) {
-            D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
-          }
-        }
-      }
-    }
-  }
-
-  METAL_FUNC void clear_results() {
-    STEEL_PRAGMA_UNROLL
-    for (int i = 0; i < TM; i++) {
-      STEEL_PRAGMA_UNROLL
-      for (int j = 0; j < TN; j++) {
-        results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
-      }
-    }
-  }
-};
-
-template <
-    typename T,
-    typename U,
-    int BM,
-    int BN,
-    int BK,
-    int WM,
-    int WN,
-    bool transpose_q,
-    bool transpose_k,
-    bool transpose_v,
-    bool MN_aligned,
-    bool K_aligned,
-    typename AccumType = typename AccumHelper<T>::accum_type,
-    typename Epilogue = TransformNone<U, AccumType>>
-struct FastAttentionKernel {
-  STEEL_CONST short tgp_padding = 16 / sizeof(T);
-  STEEL_CONST short float_padding = 16 / sizeof(float);
-  STEEL_CONST short tgp_mem_size_q =
-      transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
-  STEEL_CONST short tgp_mem_size_k =
-      transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
-  STEEL_CONST short tgp_mem_size_v =
-      transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
-  STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
-
-  // maxes, rowsums, rescale
-  STEEL_CONST short tgp_mem_size_corrections =
-      4 * (BM * sizeof(float) + float_padding);
-
-  STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
-
-  STEEL_CONST short tgp_mem_size = share_kv_smem
-      ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
-          tgp_mem_size_corrections
-      : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
-          tgp_mem_size_corrections + tgp_mem_size_v;
-
-  STEEL_CONST short tgp_size = WM * WN * 32;
-
-  static_assert(transpose_q == false, "Expected Q not transposed.");
-  static_assert(transpose_k == true, "Expected K transposed.");
-  static_assert(transpose_v == false, "Expected V not transposed.");
-  static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
-
-  using loader_q_t = BlockLoaderFA<
-      T,
-      transpose_q ? BK : BM,
-      transpose_q ? BM : BK,
-      transpose_q ? BM + tgp_padding : BK + tgp_padding,
-      !transpose_q,
-      tgp_size>;
-
-  using loader_k_t = BlockLoaderFA<
-      T,
-      transpose_k ? BN : BK,
-      transpose_k ? BK : BN,
-      transpose_k ? BK + tgp_padding : BN + tgp_padding,
-      transpose_k,
-      tgp_size>;
-
-  using loader_v_t = BlockLoaderFA<
-      T,
-      transpose_v ? BK : BN,
-      transpose_v ? BN : BK,
-      transpose_v ? BN + tgp_padding : BK + tgp_padding,
-      transpose_v,
-      tgp_size>;
-
-  using mma_qk_t = BlockMMAFA<
-      T,
-      U,
-      BM,
-      BN,
-      BK,
-      WM,
-      WN,
-      transpose_q,
-      transpose_k,
-      transpose_q ? BM + tgp_padding : BK + tgp_padding,
-      transpose_k ? BK + tgp_padding : BN + tgp_padding,
-      AccumType,
-      Epilogue>;
-
-  using mma_sv_t = BlockMMAFA<
-      T,
-      U,
-      BM,
-      BK,
-      BN,
-      WM,
-      WN,
-      false,
-      transpose_v,
-      BN + tgp_padding,
-      BK + tgp_padding,
-      AccumType,
-      Epilogue>;
-
-  /* Main kernel function */
-  template <bool M_aligned, bool N_aligned, bool K_aligned_>
-  static METAL_FUNC void gemm_loop(
-      threadgroup T* As [[threadgroup(0)]],
-      threadgroup T* Bs [[threadgroup(1)]],
-      const int gemm_k_iterations,
-      thread loader_k_t& loader_b,
-      thread mma_qk_t& mma_op,
-      thread const short& tgp_bm,
-      thread const short& tgp_bn,
-      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
-    // Appease the compiler
-    (void)l;
-    (void)tgp_bm;
-
-    short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
-
-    // not valid for gemm_k_iterations > 1 (so, BK == d_k)
-    for (int k = 0; k < gemm_k_iterations; k++) {
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      if (N_aligned) {
-        loader_b.load_unsafe();
-      } else {
-        loader_b.load_safe(tile_dims_B);
-      }
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      // Multiply and accumulate threadgroup elements
-      mma_op.mma(As, Bs);
-    }
-  }
-
-  static METAL_FUNC void initialize_corrections(
-      threadgroup float* C,
-      uint simd_lane_id,
-      uint simd_group_id) {
-    if (simd_group_id == 0) {
-      threadgroup float* maxes = C;
-      threadgroup float* sums = C + (BM + float_padding);
-      threadgroup float* o_rescale = sums + (BM + float_padding);
-      threadgroup float* output_rescale = o_rescale + (BM + float_padding);
-
-      if (simd_lane_id < BM) {
-        maxes[simd_lane_id] = -INFINITY; // m_i
-        sums[simd_lane_id] = 0.f; // l_i
-        o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
-        output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
-      }
-    }
-  }
-
-  static METAL_FUNC void rescale_ss(
-      threadgroup T* Ss,
-      threadgroup float* Corrections,
-      uint simd_group_id,
-      uint simd_lane_id,
-      short2 local_blocks,
-      float alpha) {
-    if (simd_group_id == 0) {
-      short row_offset = BM + float_padding;
-      threadgroup float* maxes = Corrections;
-      threadgroup float* sums = Corrections + row_offset;
-      threadgroup float* o_rescale = sums + row_offset;
-      threadgroup float* output_scales = o_rescale + row_offset;
-
-      if (simd_lane_id < uint(local_blocks.y)) {
-        float m_i_old = maxes[simd_lane_id];
-        float l_i_old = sums[simd_lane_id];
-
-        float m_i_new = m_i_old;
-        float l_i_new = l_i_old;
-
-        short offset = simd_lane_id * (BN + tgp_padding);
-
-        float m_ij = -INFINITY;
-
-        for (short j = 0; j < local_blocks.x; j++) {
-          float val = alpha * float(Ss[offset + j]);
-          m_ij = max(m_ij, val);
-        }
-
-        m_i_new = max(m_ij, m_i_new);
-
-        float rowsum = 0.f; // lij
-
-        for (short j = 0; j < local_blocks.x; j++) {
-          float val = alpha * float(Ss[offset + j]);
-          float P_i_j = exp(val - m_ij);
-          rowsum += P_i_j;
-          P_i_j = P_i_j * exp(m_ij - m_i_new);
-          Ss[offset + j] = T(P_i_j);
-        }
-
-        l_i_new =
-            exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
-        maxes[simd_lane_id] = m_i_new;
-        sums[simd_lane_id] = l_i_new;
-        float rescale = l_i_old * exp(m_i_old - m_i_new);
-        o_rescale[simd_lane_id] = rescale;
-        output_scales[simd_lane_id] = 1.0 / l_i_new;
-      }
-    }
-  }
-
-  /* Main kernel function */
-  static METAL_FUNC void run(
-      const device T* Q [[buffer(0)]],
-      const device T* K [[buffer(1)]],
-      const device T* V [[buffer(2)]],
-      device U* O [[buffer(3)]],
-      const constant MLXFastAttentionParams* params [[buffer(4)]],
-      threadgroup T* Qs [[threadgroup(0)]],
-      threadgroup T* Ks [[threadgroup(1)]],
-      threadgroup T* Ss [[threadgroup(2)]],
-      threadgroup T* Vs [[threadgroup(3)]],
-      threadgroup float* Corrections [[threadgroup(4)]],
-      uint simd_lane_id [[thread_index_in_simdgroup]],
-      uint simd_group_id [[simdgroup_index_in_threadgroup]],
-      uint3 tid [[threadgroup_position_in_grid]],
-      uint3 lid [[thread_position_in_threadgroup]]) {
-    // Pacifying compiler
-    (void)lid;
-
-    const int tid_y = ((tid.y) << params->swizzle_log) +
-        ((tid.x) & ((1 << params->swizzle_log) - 1));
-    const int tid_x = (tid.x) >> params->swizzle_log;
-
-    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
-      return;
-    }
-
-    threadgroup_barrier(mem_flags::mem_none);
-
-    // Find block in Q, O; and head in K, V.
-    const int c_row = tid_y * BM;
-
-    Q += transpose_q ? c_row : c_row * params->ldq;
-    thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
-
-    short tgp_bm = min(BM, params->M - c_row);
-    short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
-
-    loader_q.load_safe(tile_dims_Q);
-
-    initialize_corrections(Corrections, simd_lane_id, simd_group_id);
-
-    O += c_row * params->ldo;
-
-    // Prepare threadgroup mma operation
-    thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
-    thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
-    thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
-    thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
-
-    for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
-         n_block++) {
-      short c_col = BN;
-
-      // Prepare threadgroup loading operations
-      short gemm_k_iterations = params->gemm_k_iterations_aligned;
-      short tgp_bn_qk = min(BN, params->N - c_col * n_block);
-      threadgroup_barrier(mem_flags::mem_none);
-
-      ///////////////////////////////////////////////////////////////////////////////
-      { // Loop over K - unaligned case
-
-        if (tgp_bm == BM && tgp_bn_qk == BN) {
-          gemm_loop<true, true, K_aligned>(
-              Qs,
-              Ks,
-              gemm_k_iterations,
-              loader_k,
-              mma_qk_op,
-              tgp_bm,
-              tgp_bn_qk);
-        } else if (tgp_bn_qk == BN) {
-          gemm_loop<false, true, K_aligned>(
-              Qs,
-              Ks,
-              gemm_k_iterations,
-              loader_k,
-              mma_qk_op,
-              tgp_bm,
-              tgp_bn_qk);
-
-        } else if (tgp_bm == BM) {
-          gemm_loop<true, false, K_aligned>(
-              Qs,
-              Ks,
-              gemm_k_iterations,
-              loader_k,
-              mma_qk_op,
-              tgp_bm,
-              tgp_bn_qk);
-
-        } else {
-          gemm_loop<false, false, K_aligned>(
-              Qs,
-              Ks,
-              gemm_k_iterations,
-              loader_k,
-              mma_qk_op,
-              tgp_bm,
-              tgp_bn_qk);
-        }
-      }
-
-      mma_qk_op.store_result_to_tgp_memory(
-          Ss, BN + tgp_padding, short2(BN, BM));
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      rescale_ss(
-          Ss,
-          Corrections,
-          simd_group_id,
-          simd_lane_id,
-          short2(tgp_bn_qk, tgp_bm),
-          params->alpha);
-
-      loader_v.load_safe(short2(BK, tgp_bn_qk));
-
-      threadgroup_barrier(mem_flags::mem_threadgroup);
-
-      threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
-      mma_softmax_sv_op.rescale_output(o_scales);
-
-      mma_softmax_sv_op.mma(Ss, Vs);
-
-      threadgroup float* final_output_scales =
-          Corrections + 3 * (BM + float_padding);
-
-      mma_softmax_sv_op.rescale_output(final_output_scales);
-
-      loader_v.next();
-      loader_k.next(BN);
-
-      mma_qk_op.clear_results();
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
-  }
-};
-
-template <
-    typename T,
-    int BM,
-    int BN,
-    int BK,
-    int WM,
-    int WN,
-    bool transpose_q,
-    bool transpose_k,
-    bool transpose_v,
-    bool MN_aligned,
-    bool K_aligned>
-[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
-    const device T* Q [[buffer(0)]],
-    const device T* K [[buffer(1)]],
-    const device T* V [[buffer(2)]],
-    device T* O [[buffer(3)]],
-    const constant MLXFastAttentionParams* params [[buffer(4)]],
-    const constant int* batch_shape [[buffer(6)]],
-    const constant size_t* batch_strides [[buffer(7)]],
-    uint simd_lane_id [[thread_index_in_simdgroup]],
-    uint simd_group_id [[simdgroup_index_in_threadgroup]],
-    uint3 tid [[threadgroup_position_in_grid]],
-    uint3 lid [[thread_position_in_threadgroup]]) {
-  using attention_kernel = FastAttentionKernel<
-      T,
-      T,
-      BM,
-      BN,
-      BK,
-      WM,
-      WN,
-      transpose_q,
-      transpose_k,
-      transpose_v,
-      MN_aligned,
-      K_aligned>;
-
-  // Adjust for batch
-  if (params->batch_ndim > 1) {
-    const constant size_t* Q_bstrides = batch_strides;
-    const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
-
-    ulong2 batch_offsets = elem_to_loc_broadcast(
-        tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
-
-    Q += batch_offsets.x;
-    K += batch_offsets.y;
-    V += batch_offsets.y;
-
-  } else {
-    Q += params->batch_stride_q * tid.z;
-    K += params->batch_stride_k * tid.z;
-    V += params->batch_stride_v * tid.z;
-  }
-
-  // same shape as input
-  O += params->batch_stride_o * tid.z;
-  threadgroup T Qs[attention_kernel::tgp_mem_size_q];
-  threadgroup T Ss[attention_kernel::tgp_mem_size_s];
-  threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
-
-  if (attention_kernel::share_kv_smem) {
-    threadgroup T Ks[attention_kernel::tgp_mem_size_k];
-    threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
-    attention_kernel::run(
-        Q,
-        K,
-        V,
-        O,
-        params,
-        Qs,
-        Ks,
-        Ss,
-        Vs,
-        Corrections,
-        simd_lane_id,
-        simd_group_id,
-        tid,
-        lid);
-  } else {
-    threadgroup T Ks[attention_kernel::tgp_mem_size_k];
-    threadgroup T Vs[attention_kernel::tgp_mem_size_v];
-    attention_kernel::run(
-        Q,
-        K,
-        V,
-        O,
-        params,
-        Qs,
-        Ks,
-        Ss,
-        Vs,
-        Corrections,
-        simd_lane_id,
-        simd_group_id,
-        tid,
-        lid);
-  }
-}
-
 
 // clang-format off
-
-// SDPA full instantiations
-#define instantiate_fast_inference_self_attention_kernel(                   \
-    itype, otype, bm, bn, bk, wm, wn)                                       \
-  template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
-                       "_itype_" #itype)]] [[kernel]] void                  \
-  attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>(    \
-      const device itype* Q [[buffer(0)]], \
-      const device itype* K [[buffer(1)]], \
-      const device itype* V [[buffer(2)]], \
-      device otype* O [[buffer(3)]], \
-      const constant MLXFastAttentionParams* params [[buffer(4)]], \
-      const constant int* batch_shape [[buffer(6)]], \
-      const constant size_t* batch_strides [[buffer(7)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]]);
-
-instantiate_fast_inference_self_attention_kernel(
-    float,
-    float,
-    16,
-    16,
-    64,
-    2,
-    2);
-instantiate_fast_inference_self_attention_kernel(
-    float,
-    float,
-    16,
-    16,
-    128,
-    2,
-    2);
-instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
-instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
-
 // SDPA vector instantiations
 #define instantiate_sdpa_vector(type, head_dim)                              \
   instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h b/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
deleted file mode 100644
index a77dad268b..0000000000
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
+++ /dev/null
@@ -1,42 +0,0 @@
-//
-//  scaled_dot_product_attention_params.h
-//  mlx
-
-#pragma once
-
-struct MLXFastAttentionParams {
-  const int M;
-  const int N;
-  const int K;
-
-  const int ldq; // ldq == ldo
-  const int ldk;
-  const int ldv;
-  const int lds;
-  const int ldo;
-
-  const int tiles_n;
-  const int tiles_m;
-
-  const int batch_stride_q;
-  const int batch_stride_k;
-  const int batch_stride_v;
-  const int batch_stride_o;
-
-  const int swizzle_log;
-  const int gemm_n_iterations_aligned;
-  const int gemm_k_iterations_aligned;
-  const int gemm_sv_m_block_iterations;
-
-  const int batch_ndim;
-  const float alpha;
-};
-
-struct MLXScaledDotProductAttentionParams {
-  // Associated dimensions & transposition information
-  const uint QUERY_SEQUENCE_LENGTH = 1;
-  const uint N_Q_HEADS = 32;
-  const uint N_KV_HEADS = 32;
-  const uint KV_TILES = 1;
-  const float INV_ALPHA = 0.08838834764831843f;
-};
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index eadf2a6633..911984ef42 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -6,9 +6,8 @@
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
-#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
+
 #include "mlx/backend/metal/kernels/steel/attn/params.h"
-#include "mlx/backend/metal/kernels/steel/gemm/params.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/utils.h"
@@ -326,7 +325,23 @@ void ScaledDotProductAttention::eval_gpu(
   const auto& q = copy_unless(is_matrix_contiguous, q_pre);
   const auto& k = copy_unless(is_matrix_contiguous, k_pre);
   const auto& v = copy_unless(is_matrix_contiguous, v_pre);
-  o.set_data(allocator::malloc_or_wait(o.nbytes()));
+
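+  // o has logical shape (B, H, L, D); the strides computed below lay it
+  // out in memory as (B, L, H, D), i.e. all heads for a given query
+  // position sit next to each other (the strided O write named in the
+  // subject line).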
+  size_t str_oD = 1;
+  size_t str_oH = o.shape(3);
+  size_t str_oL = o.shape(1) * str_oH;
+  size_t str_oB = o.shape(2) * str_oL;
+
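+  // The buffer covers data_size() elements with no holes, so it is
+  // flagged contiguous, but it is neither row- nor col-contiguous: the
+  // H and L axes are swapped in memory relative to the logical shape.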
+  array::Flags flags{
+      /* bool contiguous = */ 1,
+      /* bool row_contiguous = */ 0,
+      /* bool col_contiguous = */ 0,
+  };
+
+  o.set_data(
+      allocator::malloc_or_wait(o.nbytes()),
+      o.data_size(),
+      {str_oB, str_oH, str_oL, str_oD},
+      flags);
 
   sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
 }