diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h index 92fd547ca..58c0866be 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -91,6 +91,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + // Move to correct block ulong3 tidl{tid.x, tid.y, tid.z}; Q += tidl.z * params->Q_strides[0] + // Batch @@ -107,26 +111,20 @@ template < tidl.y * params->O_strides[1] + // Head tidl.x * BQ * params->O_strides[2]; // Seqeunce - constexpr int padQ = 0; // 16 / sizeof(T); - constexpr int padK = 0; // 16 / sizeof(T); - constexpr int padV = 0; // 16 / sizeof(T); + // Prepare threadgroup memory + constexpr short padQ = 0; // 16 / sizeof(T); + constexpr short padK = 0; // 16 / sizeof(T); + constexpr short padV = 0; // 16 / sizeof(T); - // using QBlockSrcShape = CShape; - // using KBlockSrcShape = CShape; - // using VBlockSrcShape = CShape; - - constexpr int LDQ_tgp = BD + padQ; - constexpr int LDK_tgp = BK + padK; - constexpr int LDV_tgp = BD + padV; - - // using QBlockDstStrides = CShape; - // using KBlockDstStrides = CShape<1, LDK_tgp>; - // using QBlockDstStrides = CShape; + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; threadgroup T Qs[BQ * (BD + padQ)]; threadgroup T Ks[(BK + padK) * BD]; threadgroup T Vs[BK * (BD + padV)]; + // Prepare block loaders using QBlockLoader = BlockLoaderT< /* typename T = */ T, /* short BROWS = */ BQ, @@ -136,6 +134,7 @@ template < /* short reduction_dim = */ 1, /* short tgp_size = */ WM * WN * 32>; + // K is loaded in transposed using KBlockLoader = BlockLoaderT< /* typename T = */ T, /* short BROWS = */ BK, @@ -163,8 +162,8 @@ template < TransformScale ts(static_cast(params->scale)); - // MMAFrag size - constexpr short kFragSize = 8; + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size using MMAFrag_acc_t = BaseMMAFrag; constexpr int kNWarps = WM * WN; @@ -189,35 +188,40 @@ template < Otile.clear(); - short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); - short sm = simd_coord.y; - short sn = simd_coord.x; - short tm = kFragSize * TQ * simd_group_id; + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; - short Qs_offset = (tm + sm) * LDQ_tgp + sn; - short Ks_offset = sm * LDK_tgp + sn; - short Vs_offset = sm * LDV_tgp + sn; + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; - constexpr int Qs_tile_stride = kFragSize; - constexpr int Ks_tile_stride = kFragSize * LDK_tgp; + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; threadgroup_barrier(mem_flags::mem_threadgroup); + // Load Q blocks apply scale loader_q.load_unsafe(); loader_q.apply_inplace_op(ts); - constexpr int kRowsPT = decltype(Stile)::kRowsPerThread; + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; AccumType max_score[kRowsPT]; AccumType sum_score[kRowsPT] = {0}; + // Init to -Inf STEEL_PRAGMA_UNROLL for (short i = 0; i < kRowsPT; ++i) { max_score[i] = Limits::min; } + // Loop over KV seq length for (int kb = 0; kb < params->NK; kb++) { - // Load Q and K blocks and apply scale + // Load K block and apply scale threadgroup_barrier(mem_flags::mem_threadgroup); loader_k.load_unsafe(); @@ -246,15 +250,15 @@ template < // Do softmax - // Row max + // Temp variables AccumType new_max[kRowsPT]; AccumType factor[kRowsPT]; - STEEL_PRAGMA_UNROLL for (short i = 0; i < kRowsPT; ++i) { new_max[i] = max_score[i]; } + // Row max Stile.template row_reduce(new_max); // exp(Si - rowmax(Si)) @@ -265,6 +269,8 @@ template < for (short i = 0; i < kRowsPT; ++i) { factor[i] = fast::exp(max_score[i] - new_max[i]); } + + // Save max for next iteration STEEL_PRAGMA_UNROLL for (short i = 0; i < kRowsPT; ++i) { max_score[i] = new_max[i]; @@ -283,20 +289,21 @@ template < // Update O Otile.template row_bin_op(factor); - // Do O = S @ V + // Load V into registers threadgroup_barrier(mem_flags::mem_threadgroup); Vtile.template load(&Vs[Vs_offset]); simdgroup_barrier(mem_flags::mem_none); + // Do O = S @ V tile_matmad(Otile, Stile, Vtile, Otile); // Prepare for next iteration - // loader_q.next(); loader_k.next(); loader_v.next(); } + // Normalize output Otile.template row_bin_op(sum_score); threadgroup_barrier(mem_flags::mem_none); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index c90b5f5dd..8ea8f1ea2 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -600,7 +600,7 @@ array scaled_dot_product_attention( * * dtype is not fp32 or fp16 */ - int threshold = 1024; // TODO: Fix after dev + int threshold = 32; // TODO: Fix after dev if (memory_efficient_threshold.has_value()) { threshold = std::max(1, memory_efficient_threshold.value()); }