From 210400d1120e142cfb0abd60a607c70740018117 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Fri, 14 Nov 2025 15:57:10 -0800 Subject: [PATCH] [WIP] Init NAX attention --- mlx/backend/metal/kernels/CMakeLists.txt | 5 + .../steel/attn/kernels/steel_attention_nax.h | 475 ++++++++ .../attn/kernels/steel_attention_nax.metal | 33 + mlx/backend/metal/kernels/steel/attn/nax.h | 1079 +++++++++++++++++ .../metal/scaled_dot_product_attention.cpp | 157 +++ 5 files changed, 1749 insertions(+) create mode 100644 mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h create mode 100644 mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal create mode 100644 mlx/backend/metal/kernels/steel/attn/nax.h diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3f04f5086..d921e557f 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -137,6 +137,11 @@ if(MLX_ENABLE_NAX) build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) + set(STEEL_NAX_ATTN_HEADERS + steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h + steel/utils/integral_constant.h) + + build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS}) endif() add_custom_command( diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h new file mode 100644 index 000000000..e067ed005 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -0,0 +1,475 @@ +// Copyright © 2024-25 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device 
MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + const metal::uniform scale2 = + make_uniform(params->scale) * make_uniform(1.44269504089f); + + // Prepare MMA tiles + constexpr short UQ = 16; + constexpr short UD = 32; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * UQ); + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / UD; + + static_assert(TQ == 1, "Check TQ"); + + using OSubTile = NAXSubTile; + NAXTile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = UQ * TQ * simd_group_id; + + Q += (tm + sm) * int(params->Q_strides[2]) + sn; + K += sm * int(params->K_strides[2]) + sn; + V += sm * int(params->V_strides[2]) + sn; + + // Init row reduction variables + constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; + + metal::vec max_score; + metal::vec sum_score{0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + const bool is_last_bq = int(tid.x) == (params->NQ_aligned); + const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); + const bool is_last_q = is_last_bq; + + const short lim_rows_q = params->qL_rem - (tm + sm); + const short lim_rows_k = params->kL_rem - sm; + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + const int is_last_k = (kb == (params->NK_aligned)); + + // Do S = Q @ K.T + constexpr short UDs = 16; + constexpr short UKs = 32; + + constexpr short TDs = BD / UDs; + constexpr short TKs = BK / UKs; + + using SSubTile = NAXSubTile; + using QSubTile = NAXSubTile; + using KSubTile = NAXSubTile; + + NAXTile Stile; + + Stile.clear(); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TKs; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TDs; id++) { 
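+          // One iteration multiplies a 16 x 16 Q subtile (UQ x UDs) by a
+          // 32 x 16 K subtile (UKs x UDs), with K taken as transposed, and
+          // accumulates into the (iq, ik) subtile of S. The id loop walks
+          // the head dimension in UDs-wide chunks, so once it finishes
+          // S(iq, ik) holds Q(iq, :) @ K(ik, :).T for this KV block.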
+ NAXTile Qtile; + NAXTile Ktile; + + const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; + const int K_load_off = + ik * UKs * int(params->K_strides[2]) + id * UDs; + + if (!align_Q && is_last_q) { + // Qtile.load_rows( + // Q + Q_load_off, + // int(params->Q_strides[2]), + // lim_rows_q - iq * UQ); + Qtile.load_safe( + Q + Q_load_off, + int(params->Q_strides[2]), + short2(BD, lim_rows_q - iq * UQ)); + } else { + Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); + } + + if (!align_K && is_last_k) { + // Ktile.load_rows( + // K + K_load_off, + // int(params->K_strides[2]), + // lim_rows_k - ik * UKs); + Ktile.load_safe( + K + K_load_off, + int(params->K_strides[2]), + short2(BD, lim_rows_k - ik * UKs)); + } else { + Ktile.load(K + K_load_off, int(params->K_strides[2])); + } + + subtile_matmad_nax( + Stile.subtile_at(iq, ik), + Qtile.subtile_at(0, 0), + metal::false_type{}, + Ktile.subtile_at(0, 0), + metal::true_type{}); + } + } + } + + // Scale S + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= float(scale2); + } + + // Scale and Retile S + constexpr short UK = 16; + constexpr short TK = BK / UK; + using PSubTile = NAXSubTile; + + NAXTile Ptile; + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Ptile.elems()[ii] = Stile.elems()[ii]; + } + + // Mask out length sequence + if (!align_K && is_last_k) { + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short col_pos = sn + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + params->qL_off + tm; + const int base_col = kb * BK; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ; + const short col_pos = base_col + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; + const auto c = col_pos + jj + sn; + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = (r < c) ? 
neg_inf : fg[loc]; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + tm; + const int base_col = kb * BK; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + using MSubTile = NAXSubTile; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ + sm; + const short col_pos = base_col + ik * UK + sn; + + MSubTile mfrag; + mfrag.load_safe( + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf; + } else { + fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); + } + } + } + } + } + + // Do softmax + + // Temp variables + metal::vec new_max; + metal::vec factor; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Ptile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Ptile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + max_score[i] = new_max[i]; + } + + // Row Sum + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i]; + } + + Ptile.template row_reduce(sum_score); + + // Update O + Otile.template row_bin_op(factor); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = P @ V + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + if constexpr (BD == 128) { + if (id == 2) { + threadgroup_barrier(mem_flags::mem_none); + } + } + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + using VSubTile = NAXSubTile; + NAXTile Vtile; + + const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; + + if (!align_K && is_last_k) { + // Vtile.load_rows( + // V + V_load_off, + // int(params->V_strides[2]), + // lim_rows_k - ik * UK); + Vtile.load_safe( + V + V_load_off, + int(params->V_strides[2]), + short2(BD, lim_rows_k - ik * UK)); + } else { + Vtile.load(V + V_load_off, int(params->V_strides[2])); + } + + subtile_matmad_nax( + Otile.subtile_at(iq, id), + Ptile.subtile_at(iq, ik), + metal::bool_constant{}, + Vtile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + + // Prepare for next iteration + K += BK * int(params->K_strides[2]); + V += BK * int(params->V_strides[2]); + } + + // Normalize output + + threadgroup_barrier(mem_flags::mem_none); + + metal::vec rcp; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + rcp[i] = (1.f / sum_score[i]); + } + + Otile.template row_bin_op(rcp); + + // Store results + O += (tm + sm) * int(params->O_strides[2]) + sn; + + if (!align_Q && is_last_q) { + if (lim_rows_q <= 0) + return; + + // Otile.store_rows(O, params->O_strides[2], lim_rows_q); + Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); + } else { + Otile.store(O, int(params->O_strides[2])); + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal new file mode 100644 index 000000000..1fba9af61 --- /dev/null +++ 
b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal @@ -0,0 +1,33 @@ +// Copyright © 2024-25 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/nax.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h" + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) + +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat); + +instantiate_attn_mask_helper(float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/nax.h b/mlx/backend/metal/kernels/steel/attn/nax.h new file mode 100644 index 000000000..27a2dac59 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/nax.h @@ -0,0 +1,1079 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename 
SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static 
constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + 
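+// A BaseNAXFrag covers one 16 x 16 simdgroup matrix: each of the 32 lanes
+// holds kElemsPerFrag = 8 values, laid out as kElemRows = 2 rows (8 rows
+// apart) of kElemCols = 4 contiguous columns starting at the lane's
+// get_coord() offset. row_reduce first folds the 4 in-thread columns, then
+// combines lanes that share a row via simd_shuffle_xor with masks 1 and 8.
+// NAXSubTile below tiles these fragments into a kRows x kCols block, and
+// NAXTile tiles subtiles to cover a simdgroup's full working set.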
+template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_ = BaseNAXFrag> +struct NAXSubTile { + using NAXFrag_t = NAXFrag_; + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + 
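+  // The *_safe variants below bound-check rows against lim_x and columns
+  // against lim_y: loads zero-fill out-of-range elements and stores skip
+  // them. The *_rows variants check only the row limit, and store_slice
+  // writes just the [start, stop) window. The attention kernel takes the
+  // safe paths for partial tiles at the sequence edges when align_Q or
+  // align_K does not hold.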
+ template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? 
CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + mpp::tensor_ops::matmul2d gemm_op; + + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + gemm_op.run(ct_a, ct_b, ct_c); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread 
elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void 
store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? 
k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d8adf8199..2eac8551f 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -12,6 +12,146 @@ namespace mlx::core::fast { namespace { + +#ifdef MLX_ENABLE_NAX + +void sdpa_full_self_attention_nax( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const float scale, + array& o, + bool do_causal_, + const std::optional& mask, + const std::optional& sinks) { + using namespace mlx::steel; + + int wm = 4; + int wn = 1; + + int bd = q.shape(-1); + int bq = 64; + int bk = 32; + + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + int qL = q.shape(2); + int kL = k.shape(2); + + const bool align_Q = (qL % bq) == 0; + const bool align_K = (kL % bk) == 0; + const bool has_mask = mask.has_value(); + const bool do_causal = do_causal_; + const bool has_sinks = sinks.has_value(); + + metal::MTLFCList func_consts = { + {&align_Q, MTL::DataType::DataTypeBool, 200}, + {&align_K, MTL::DataType::DataTypeBool, 201}, + {&has_mask, MTL::DataType::DataTypeBool, 300}, + {&do_causal, MTL::DataType::DataTypeBool, 301}, + {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + + std::string base_name; + concatenate( + base_name, + "steel_attention_", + type_to_name(q), + "_bq", + bq, + "_bk", + bk, + "_bd", + bd, + "_wm", + wm, + "_wn", + wn, + "_mask", + type_to_name(has_mask ? *mask : q)); + + std::string hash_name; + concatenate( + hash_name, + base_name, + "_align_Q_", + (align_Q ? 't' : 'n'), + "_align_K_", + (align_K ? 't' : 'n'), + "_has_mask_", + (has_mask ? 't' : 'n'), + "_do_causal_", + (do_causal ? 't' : 'n'), + "_has_sinks_", + (has_sinks ? 
't' : 'n')); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + const int NQ = (qL + bq - 1) / bq; + const int NK = (kL + bk - 1) / bk; + + const int NQ_aligned = qL / bq; + const int NK_aligned = kL / bk; + + AttnParams params{ + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, + + /* int qL = */ qL, + /* int kL = */ kL, + + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, + + /* int NQ = */ NQ, + /* int NK = */ NK, + + /* int NQ_aligned = */ NQ_aligned, + /* int NK_aligned = */ NK_aligned, + + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ (kL - qL), + + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_output_array(o, 3); + compute_encoder.set_bytes(params, 4); + + if (has_mask) { + auto& m = *mask; + + AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { + m.strides(0), m.strides(1), m.strides(2)}}; + + compute_encoder.set_bytes(mask_params, 5); + compute_encoder.set_input_array(m, 6); + } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 7); + } + + MTL::Size grid_dims = MTL::Size(NQ, H, B); + MTL::Size group_dims = MTL::Size(32, wm, wn); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void sdpa_full_self_attention_metal( const Stream& s, metal::Device& d, @@ -23,6 +163,23 @@ void sdpa_full_self_attention_metal( bool do_causal_, const std::optional& mask, const std::optional& sinks) { +#ifdef MLX_ENABLE_NAX + if (metal::is_nax_available() && q.shape(3) != 80 && + (q.dtype() != float32 || env::enable_tf32())) { + return sdpa_full_self_attention_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& q = */ q, + /* const array& k = */ k, + /* const array& v = */ v, + /* const float scale = */ scale, + /* array& o = */ o, + /* bool do_causal_ = */ do_causal_, + /* const std::optional& mask = */ mask, + /* const std::optional& sinks = */ sinks); + } +#endif // MLX_ENABLE_NAX + using namespace mlx::steel; int wm = 4;