diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
index d3872fc8c..92fd547ca 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -29,6 +29,48 @@ struct TransformScale {
   }
 };
 
+struct MaxOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return metal::max(x, y);
+  }
+};
+
+struct SumOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return x + y;
+  }
+};
+
+struct MulOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return x * y;
+  }
+};
+
+struct SubOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return x - y;
+  }
+};
+
+struct ExpSubOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return fast::exp(x - y);
+  }
+};
+
+struct DivOp {
+  template <typename T>
+  METAL_FUNC static constexpr T apply(T x, T y) {
+    return x / y;
+  }
+};
+
 // clang-format off
 template <
     typename T,
@@ -130,11 +172,12 @@
       BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
       "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
 
-  constexpr int TQ = BQ / (kNWarps * kFragSize); // Q seq frags per warp
-  constexpr int TK =
-      BK / kFragSize; // KV sequence frags (all warps load the same frags)
-  constexpr int TD =
-      BD / kFragSize; // HeadDim frags (all warps load the same frags)
+  // Q seq frags per warp
+  constexpr int TQ = BQ / (kNWarps * kFragSize);
+  // KV sequence frags (all warps load the same frags)
+  constexpr int TK = BK / kFragSize;
+  // HeadDim frags (all warps load the same frags)
+  constexpr int TD = BD / kFragSize;
 
   static_assert(TQ == 1, "Check TQ");
 
@@ -163,6 +206,16 @@
   loader_q.load_unsafe();
   loader_q.apply_inplace_op(ts);
 
+  constexpr int kRowsPT = decltype(Stile)::kRowsPerThread;
+
+  AccumType max_score[kRowsPT];
+  AccumType sum_score[kRowsPT] = {0};
+
+  STEEL_PRAGMA_UNROLL
+  for (short i = 0; i < kRowsPT; ++i) {
+    max_score[i] = Limits<AccumType>::min;
+  }
+
   for (int kb = 0; kb < params->NK; kb++) {
     // Load Q and K blocks and apply scale
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -193,6 +246,43 @@
 
     // Do softmax
 
+    // Row max
+    AccumType new_max[kRowsPT];
+    AccumType factor[kRowsPT];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      new_max[i] = max_score[i];
+    }
+
+    Stile.template row_reduce<MaxOp>(new_max);
+
+    // exp(Si - rowmax(Si))
+    Stile.template row_bin_op<ExpSubOp>(new_max);
+
+    // Factor exp(rowmax(Si) - rowmax(Si-1))
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      factor[i] = fast::exp(max_score[i] - new_max[i]);
+    }
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      max_score[i] = new_max[i];
+    }
+
+    // Row Sum
+    AccumType sum_score_tmp[kRowsPT] = {0};
+    Stile.template row_reduce<SumOp>(sum_score_tmp);
+
+    // Update norm
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kRowsPT; ++i) {
+      sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
+    }
+
+    // Update O
+    Otile.template row_bin_op<MulOp>(factor);
+
     // Do O = S @ V
     threadgroup_barrier(mem_flags::mem_threadgroup);
    Vtile.template load(&Vs[Vs_offset]);
@@ -207,6 +297,7 @@
     loader_v.next();
   }
 
+  Otile.template row_bin_op<DivOp>(sum_score);
   threadgroup_barrier(mem_flags::mem_none);
 
   // Store results
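Note on the steel_attention.h hunks above: together with the MaxOp/SumOp/MulOp/ExpSubOp/DivOp functors, the per-row max_score/sum_score state implements a streaming ("online") softmax, so the attention weights for a query block are never materialized for the full KV sequence; each KV block is folded into the running output as it arrives. Below is a minimal scalar sketch of that recurrence in plain C++ (RowState and process_block are illustrative names, not kernel code; the kernel performs the same steps per simdgroup-fragment row via row_reduce/row_bin_op):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Running per-row softmax state, mirroring max_score[] / sum_score[].
struct RowState {
  float max_score = -INFINITY;
  float sum_score = 0.0f;
};

// Fold one KV block into the running output `o` for a single query row.
// `scores` is this block's row of S = Q K^T * scale, `v` its value rows.
void process_block(RowState& st,
                   std::vector<float>& o,
                   const std::vector<float>& scores,
                   const std::vector<std::vector<float>>& v) {
  // new_max = max(old_max, rowmax(S))           -> row_reduce<MaxOp>
  float new_max = st.max_score;
  for (float s : scores) new_max = std::max(new_max, s);

  // factor = exp(old_max - new_max)             -> the factor[] loop
  float factor = std::exp(st.max_score - new_max);
  st.max_score = new_max;

  // P = exp(S - new_max) and rowsum(P)          -> row_bin_op<ExpSubOp>,
  //                                                row_reduce<SumOp>
  std::vector<float> p(scores.size());
  float block_sum = 0.0f;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    p[i] = std::exp(scores[i] - new_max);
    block_sum += p[i];
  }
  st.sum_score = st.sum_score * factor + block_sum; // "Update norm"

  // O = O * factor + P @ V                      -> row_bin_op<MulOp>, S @ V
  for (std::size_t d = 0; d < o.size(); ++d) {
    o[d] *= factor;
    for (std::size_t i = 0; i < p.size(); ++i) o[d] += p[i] * v[i][d];
  }
}

int main() {
  RowState st;
  std::vector<float> o(4, 0.0f);
  std::vector<std::vector<float>> v = {{1, 0, 0, 0}, {0, 1, 0, 0}};
  process_block(st, o, {0.5f, 1.5f}, v);
  process_block(st, o, {2.5f, -1.0f}, v);
  for (float& x : o) x /= st.sum_score; // the final row_bin_op<DivOp> pass
  std::printf("o = %g %g %g %g\n", o[0], o[1], o[2], o[3]);
}

Rescaling both O and sum_score by exp(old_max - new_max) keeps every argument to exp non-positive, which is what makes the block-at-a-time evaluation numerically safe even with fast::exp.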
diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h
index c784efb00..5ddd37ac6 100644
--- a/mlx/backend/metal/kernels/steel/attn/mma.h
+++ b/mlx/backend/metal/kernels/steel/attn/mma.h
@@ -59,6 +59,8 @@ struct BaseMMAFrag {
 
   typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
   typedef metal::vec<T, kElemsPerFrag> frag_type;
+  typedef metal::vec<T, kElemRows> row_frag_type;
+  typedef metal::vec<T, kElemCols> col_frag_type;
 
   METAL_FUNC static constexpr short2 get_coord(
       ushort simd_lane_id [[thread_index_in_simdgroup]]) {
@@ -182,6 +184,35 @@ struct BaseMMAFrag {
       thread mat_type& C) {
     simdgroup_multiply_accumulate(D, A, B, C);
   }
+
+  template <typename Op>
+  METAL_FUNC static constexpr void row_reduce(
+      thread const frag_type& inp_vals,
+      thread T* reduced_vals) {
+    T thr_reduce = Op::apply(inp_vals.x, inp_vals.y);
+
+    T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
+    qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
+
+    T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
+    sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
+
+    reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce);
+  }
+
+  template <typename Op>
+  METAL_FUNC static constexpr void row_bin_op(
+      thread frag_type& inp_vals,
+      thread T* row_vals) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kElemRows; i++) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kElemCols; j++) {
+        inp_vals[i * kElemCols + j] =
+            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
+      }
+    }
+  }
 };
 template <
     typename T,
@@ -205,6 +236,9 @@ struct MMATile {
   STEEL_CONST int kNumFrags = kTileRows * kTileCols;
   STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
 
+  STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows;
+  STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols;
+
   typedef typename MMAFrag_t::mat_type mat_type;
   typedef typename MMAFrag_t::frag_type frag_type;
 
@@ -246,6 +280,30 @@
     return reinterpret_cast<thread elem_type*>(val_frags);
   }
 
+  template <typename Op>
+  METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::template row_reduce<Op>(
+            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
+      }
+    }
+  }
+
+  template <typename Op>
+  METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < kTileRows; ++i) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < kTileCols; ++j) {
+        MMAFrag_t::template row_bin_op<Op>(
+            frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]);
+      }
+    }
+  }
+
   template <typename U, int w_x, int w_y, int str_x, int str_y>
   METAL_FUNC void load(const threadgroup U* src) {
     STEEL_PRAGMA_UNROLL
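Note on the mma.h hunks: BaseMMAFrag::row_reduce first reduces the two elements a thread owns (inp_vals.x, inp_vals.y), then finishes across the simdgroup with two simd_shuffle_xor butterfly steps using masks 1 and 8. Assuming the lane layout those masks imply for an 8x8 simdgroup_matrix fragment, namely that the four lanes covering one matrix row have lane ids differing only in bits 0 and 3 (an inference from the masks, not stated in the diff), the butterfly can be emulated on the host like this:

#include <algorithm>
#include <cstdio>

struct MaxOp {
  static float apply(float x, float y) { return std::max(x, y); }
};

// Emulates one 32-lane simdgroup: every lane simultaneously computes
// v[lane] = Op(v[lane], v[lane ^ mask]), which is what the kernel's
// simd_shuffle_xor followed by Op::apply does on the GPU.
template <typename Op>
void butterfly_step(float v[32], int mask) {
  float next[32];
  for (int lane = 0; lane < 32; ++lane)
    next[lane] = Op::apply(v[lane], v[lane ^ mask]);
  std::copy(next, next + 32, v);
}

int main() {
  // Each lane starts with its thread-local reduce (Op over its 2 elements).
  float v[32];
  for (int lane = 0; lane < 32; ++lane)
    v[lane] = float((lane * 13) % 29); // arbitrary stand-in values

  butterfly_step<MaxOp>(v, 1); // qgr_reduce: partner is lane ^ 1
  butterfly_step<MaxOp>(v, 8); // sgr_reduce: partner is lane ^ 8

  // Lanes 0, 1, 8, 9 form one row group and now agree on the row maximum.
  std::printf("row result at lanes 0,1,8,9: %g %g %g %g\n",
              v[0], v[1], v[8], v[9]);
}

After the two steps every lane in a row group holds the same reduced value, which is why row_reduce can combine sgr_reduce into reduced_vals[0] unconditionally, with no broadcast step.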