From 168a3a464a541cbaacb84353def4c68372257b77 Mon Sep 17 00:00:00 2001
From: Jagrit Digani
Date: Tue, 19 Nov 2024 10:24:52 -0800
Subject: [PATCH] [WIP]: Loading and Matmuls added

---
 .../steel/attn/kernels/steel_attention.h      | 160 +++++++++++++++++-
 .../steel/attn/kernels/steel_attention.metal  |   7 +-
 mlx/backend/metal/kernels/steel/attn/loader.h | 127 ++++++++++++++
 mlx/backend/metal/kernels/steel/attn/mma.h    |   8 +-
 mlx/backend/metal/kernels/steel/gemm/mma.h    |   4 +-
 .../metal/scaled_dot_product_attention.cpp    |   6 +-
 6 files changed, 295 insertions(+), 17 deletions(-)

diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
index 61d1b9563..d3872fc8c 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -19,6 +19,16 @@ constant bool do_gather [[function_constant(300)]];
 
 constant bool gather_bias = do_gather && use_out_source;
 
+template <typename T>
+struct TransformScale {
+  T scale;
+  METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
+
+  METAL_FUNC T apply(T x) const {
+    return scale * x;
+  }
+};
+
 // clang-format off
 template <
     typename T,
@@ -55,11 +65,151 @@ template <
        tidl.y * params->O_strides[1] + // Head
        tidl.x * BQ * params->O_strides[2]; // Seqeunce
 
-  for (int i = simd_group_id * 32 + simd_lane_id; i < BQ * params->D;
-       i += WM * WN * 32) {
-    int r = i / params->D;
-    int c = i % params->D;
+  constexpr int padQ = 0; // 16 / sizeof(T);
+  constexpr int padK = 0; // 16 / sizeof(T);
+  constexpr int padV = 0; // 16 / sizeof(T);
 
-    O[params->O_strides[2] * r + c] = T(0);
+  // using QBlockSrcShape = CShape<BQ, BD>;
+  // using KBlockSrcShape = CShape<BK, BD>;
+  // using VBlockSrcShape = CShape<BK, BD>;
+
+  constexpr int LDQ_tgp = BD + padQ;
+  constexpr int LDK_tgp = BK + padK;
+  constexpr int LDV_tgp = BD + padV;
+
+  // using QBlockDstStrides = CShape<LDQ_tgp, 1>;
+  // using KBlockDstStrides = CShape<1, LDK_tgp>;
+  // using VBlockDstStrides = CShape<LDV_tgp, 1>;
+
+  threadgroup T Qs[BQ * (BD + padQ)];
+  threadgroup T Ks[(BK + padK) * BD];
+  threadgroup T Vs[BK * (BD + padV)];
+
+  using QBlockLoader = BlockLoaderT<
+      /* typename T = */ T,
+      /* short BROWS = */ BQ,
+      /* short BCOLS = */ BD,
+      /* short kDstStrRow = */ LDQ_tgp,
+      /* short kDstStrCol = */ 1,
+      /* short reduction_dim = */ 1,
+      /* short tgp_size = */ WM * WN * 32>;
+
+  using KBlockLoader = BlockLoaderT<
+      /* typename T = */ T,
+      /* short BROWS = */ BK,
+      /* short BCOLS = */ BD,
+      /* short kDstStrRow = */ 1,
+      /* short kDstStrCol = */ LDK_tgp,
+      /* short reduction_dim = */ 0,
+      /* short tgp_size = */ WM * WN * 32>;
+
+  using VBlockLoader = BlockLoaderT<
+      /* typename T = */ T,
+      /* short BROWS = */ BK,
+      /* short BCOLS = */ BD,
+      /* short kDstStrRow = */ LDV_tgp,
+      /* short kDstStrCol = */ 1,
+      /* short reduction_dim = */ 0,
+      /* short tgp_size = */ WM * WN * 32>;
+
+  QBlockLoader loader_q(
+      Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
+  KBlockLoader loader_k(
+      K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
+  VBlockLoader loader_v(
+      V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
+
+  TransformScale<T> ts(static_cast<T>(params->scale));
+
+  // MMAFrag size
+  constexpr short kFragSize = 8;
+  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
+
+  constexpr int kNWarps = WM * WN;
+  static_assert(
+      BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
+      "Each simdgroup must host at least 1 simdgroup matrix along Q sequence.");
+
+  constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
+  constexpr int TK =
+      BK / kFragSize; // KV sequence frags (all warps load the same frags)
+  constexpr int TD =
+      BD / kFragSize; // HeadDim frags (all warps load the same frags)
+
+  static_assert(TQ == 1, "Check TQ");
+
+  MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
+  MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
+  MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
+  MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
+  MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
+
+  Otile.clear();
+
+  short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
+  short sm = simd_coord.y;
+  short sn = simd_coord.x;
+  short tm = kFragSize * TQ * simd_group_id;
+
+  short Qs_offset = (tm + sm) * LDQ_tgp + sn;
+  short Ks_offset = sm * LDK_tgp + sn;
+  short Vs_offset = sm * LDV_tgp + sn;
+
+  constexpr int Qs_tile_stride = kFragSize;
+  constexpr int Ks_tile_stride = kFragSize * LDK_tgp;
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  loader_q.load_unsafe();
+  loader_q.apply_inplace_op(ts);
+
+  for (int kb = 0; kb < params->NK; kb++) {
+    // Load Q and K blocks and apply scale
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    loader_k.load_unsafe();
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Do S = Q @ K.T
+    Stile.clear();
+
+    for (short dd = 0; dd < TD; dd++) {
+      simdgroup_barrier(mem_flags::mem_none);
+
+      Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
+          &Qs[Qs_offset + dd * Qs_tile_stride]);
+      Ktile.template load<T, 1, 1, LDK_tgp, 1>(
+          &Ks[Ks_offset + dd * Ks_tile_stride]);
+
+      simdgroup_barrier(mem_flags::mem_none);
+
+      tile_matmad(Stile, Qtile, Ktile, Stile);
+    }
+
+    simdgroup_barrier(mem_flags::mem_none);
+
+    // Load V blocks
+    loader_v.load_unsafe();
+
+    // Do softmax
+
+    // Do O = S @ V
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
+
+    simdgroup_barrier(mem_flags::mem_none);
+
+    tile_matmad(Otile, Stile, Vtile, Otile);
+
+    // Prepare for next iteration
+    // loader_q.next();
+    loader_k.next();
+    loader_v.next();
   }
+
+  threadgroup_barrier(mem_flags::mem_none);
+
+  // Store results
+  O += (tm + sm) * params->O_strides[2] + sn;
+  Otile.template store<T, 1, 1>(O, params->O_strides[2]);
 }
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index bbfecd7bf..f509f1b1c 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -1,7 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 // clang-format off
-#include "mlx/backend/metal/kernels/bf16.h"
+// #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 #include "mlx/backend/metal/kernels/steel/attn/attn.h"
@@ -21,11 +21,12 @@
      uint3 lid [[thread_position_in_threadgroup]]);
 
 #define instantiate_attn_shapes_helper(iname, itype) \
-    instantiate_attn(iname, itype, 16, 16, 64, 2, 2) \
+    instantiate_attn(iname, itype, 32, 32, 64, 4, 1) \
+    instantiate_attn(iname, itype, 32, 16, 64, 4, 1) \
 
 instantiate_attn_shapes_helper(float16, half);
-instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+// instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
 
 instantiate_attn_shapes_helper(float32, float);
 // clang-format on
diff --git a/mlx/backend/metal/kernels/steel/attn/loader.h b/mlx/backend/metal/kernels/steel/attn/loader.h
index 3f084d8ec..2849c00f1 100644
--- a/mlx/backend/metal/kernels/steel/attn/loader.h
+++ b/mlx/backend/metal/kernels/steel/attn/loader.h
@@ -133,5 +133,132 @@ struct BlockLoader {
   }
 };
 
+template <int R, int C>
+struct CShape {
+  STEEL_CONST int kRows = R;
+  STEEL_CONST int kCols = C;
+};
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short kDstStrRow,
+    short kDstStrCol,
+    short reduction_dim,
+    short tgp_size,
+    short n_reads = (BCOLS * BROWS) / (tgp_size),
+    short TCOLS = BCOLS / n_reads,
+    short TROWS = tgp_size / TCOLS>
+struct BlockLoaderT {
+  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+  STEEL_CONST short vec_size = n_reads;
+
+  // Leading dimension for src
+  const int src_ld;
+  const int tile_stride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  /* Constructor */
+  METAL_FUNC BlockLoaderT(
+      const device T* src_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
+        src(src_ + bi * src_ld + bj) {}
+
+  /* Apply operation to threadgroup without bound checking */
+  template <typename UnaryOp>
+  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] =
+            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - with bound checking */
+  METAL_FUNC void load_safe(short2 src_tile_dim) const {
+    src_tile_dim = src_tile_dim - short2(bj, bi);
+
+    // Skip loading if thread has no valid reads
+    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * kDstStrRow + j * kDstStrCol] = T(0);
+        }
+      }
+      return;
+    }
+
+    // Use fast thread memory for bound checks
+    bool tmp_idx[vec_size];
+    T tmp_val[vec_size];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      // Make sure tmp_idx only contains valid indices
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+      }
+
+      // Read valid indices into tmp_val
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+      }
+
+      // Zero out unneeded values
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+      }
+
+      // Copy values to threadgroup memory
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    src += tile_stride;
+  }
+};
+
 } // namespace steel
 } // namespace mlx
diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h
index ad7a8c529..c784efb00 100644
--- a/mlx/backend/metal/kernels/steel/attn/mma.h
+++ b/mlx/backend/metal/kernels/steel/attn/mma.h
@@ -358,12 +358,12 @@ METAL_FUNC void tile_matmad(
     thread MMATile<Btype, K, N, MMAFragB>& B,
     thread MMATile<Ctype, M, N, MMAFragC>& C) {
   STEEL_PRAGMA_UNROLL
-  for (short m = 0; m < M; ++m) {
+  for (short k = 0; k < K; ++k) {
     STEEL_PRAGMA_UNROLL
-    for (short n = 0; n < N; ++n) {
-      short n_serp = (m % 2) ? (N - 1 - n) : n;
+    for (short m = 0; m < M; ++m) {
       STEEL_PRAGMA_UNROLL
-      for (short k = 0; k < K; ++k) {
+      for (short n = 0; n < N; ++n) {
+        short n_serp = (m % 2) ? (N - 1 - n) : n;
         MMATile<Dtype, M, N, MMAFragD>::MMAFrag_t::mma(
             D.frag_at(m, n_serp),
             A.frag_at(m, k),
diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h
index 7ad6c0fe6..aea235abb 100644
--- a/mlx/backend/metal/kernels/steel/gemm/mma.h
+++ b/mlx/backend/metal/kernels/steel/gemm/mma.h
@@ -385,9 +385,9 @@ struct BlockMMA {
   STEEL_CONST short TN_stride = kFragSize * WN;
 
   // Warp tile size along M
-  STEEL_CONST short TM = BM / TM_stride;
+  STEEL_CONST short TM = BM / (kFragSize * WM);
   // Warp tile size along N
-  STEEL_CONST short TN = BN / TN_stride;
+  STEEL_CONST short TN = BN / (kFragSize * WN);
 
   // Threadgroup A strides
   STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index cfb57da93..8599afbed 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -26,10 +26,10 @@ void sdpa_full_self_attention_metal(
     array& o) {
   using namespace mlx::steel;
 
-  int wm = 2;
-  int wn = 2;
+  int wm = 4;
+  int wn = 1;
 
-  int bq = 16;
+  int bq = 32;
   int bk = 16;
   int bd = 64;