// Copyright © 2025 Apple Inc.

#pragma once

#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"

#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// NAX Steel with new tiles
///////////////////////////////////////////////////////////////////////////////

struct BaseNAXFrag {
  // A fragment is a 16 x 16 tile held cooperatively by the 32 threads of a
  // simdgroup: each thread owns 8 elements, laid out as 2 rows of 4
  // contiguous columns, with the two rows 8 apart.
  STEEL_CONST short kFragRows = 16;
  STEEL_CONST short kFragCols = 16;
  STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;

  STEEL_CONST short kElemRows = 2;
  STEEL_CONST short kElemCols = 4;
  STEEL_CONST short kElemRowsJump = 8;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "NAXFrag shape is not consistent with NAXFrag size");

  template <typename U>
  using dtype_mat_t =
      typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;

  template <typename U>
  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

  METAL_FUNC static short2 get_coord() {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;
    return short2{fn, fm};
  }

  METAL_FUNC static short2 get_coord(short idx) {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;
    return short2{fn, fm};
  }
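
  // Worked example of the mapping above (illustrative only; note that the
  // load/store helpers below currently pass sc = short2{0, 0} instead of
  // get_coord()): lane 0 -> (col 0, row 0), lane 1 -> (col 4, row 0),
  // lane 2 -> (col 0, row 1), lane 8 -> (col 8, row 0). Each lane covers
  // cols [fn, fn + 3] of rows fm and fm + kElemRowsJump, so 32 lanes x 8
  // elements tile the full 16 x 16 fragment.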

  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      // Fast path for unit column stride: contiguous 4-element reads.
      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        }
      }
    }
  }

  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_rows(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] =
                static_cast<T>(src[r * str_x + (c + j) * str_y]);
          }
        }
      } else {
        // Zero only this row's elements so rows already loaded are kept.
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_safe(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_rows(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + (c + j) * str_y] =
                static_cast<U>(src[i * kElemCols + j]);
          }
        }
      }
    }
  }

  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_safe(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = short2{0, 0}; // get_coord();

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;

      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_slice(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;

    const short2 sc = short2{0, 0}; // get_coord();

    const_for_loop<0, kElemRows, 1>([&](auto idx_row) {
      const auto r = off_x + idx_row * Int<kElemRowsJump>{};
      if (r >= stop_x - sc.y || r < start_x - sc.y) {
        return;
      }
      const_for_loop<0, kElemCols, 1>([&](auto idx_col) {
        const auto c = off_y + idx_col;
        if (c >= stop_y - sc.x || c < start_y - sc.x) {
          return;
        }
        const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;
        dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =
            static_cast<U>(src[src_idx]);
      });
    });
  }
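
  // Example (illustrative): picking between the bounds-checked variants for
  // a fragment at the edge of a matrix; `ptr`, `ld`, `rows_left`, and
  // `cols_left` are hypothetical.
  //
  //   BaseNAXFrag::dtype_frag_t<float> frag;
  //   BaseNAXFrag::load(frag, ptr, ld, Int<1>{});  // fully in bounds
  //   BaseNAXFrag::load_rows(frag, ptr, ld, Int<1>{}, rows_left);
  //   BaseNAXFrag::load_safe(frag, ptr, ld, Int<1>{}, rows_left, cols_left);
  //
  // Out-of-bounds elements are zero-filled so accumulation stays correct.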

  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_reduce(
      thread const dtype_frag_t<T>& inp_vals,
      thread T* reduced_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      // Reduce the 4 values held by this thread ...
      T thr_reduce = Op::apply(
          Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),
          Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));

      // ... then combine across the lanes that share the same row.
      T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
      qgr_reduce = Op::apply(thr_reduce, qgr_reduce);

      T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
      sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);

      reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);
    }
  }

  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_bin_op(
      thread dtype_frag_t<T>& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};

template <
    typename T,
    short kRows_,
    short kCols_,
    typename NAXFrag_ = BaseNAXFrag>
struct NAXSubTile {
  using NAXFrag_t = NAXFrag_;

  STEEL_CONST short kRows = kRows_;
  STEEL_CONST short kCols = kCols_;

  STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;
  STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;
  STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;

  STEEL_CONST short kSubTileRows = kRows / kFragRows;
  STEEL_CONST short kSubTileCols = kCols / kFragCols;

  STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols;
  STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag;

  STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows;
  STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols;

  STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;
  STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;
  STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;

  using frag_type = typename NAXFrag_t::template dtype_frag_t<T>;

  frag_type val_frags[kNumFrags];

  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  METAL_FUNC constexpr thread frag_type& frag_at(
      const short i,
      const short j) {
    return val_frags[i * kSubTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kSubTileCols + j];
  }

  template <short i, short j>
  METAL_FUNC constexpr thread frag_type& frag_at() {
    return val_frags[i * kSubTileCols + j];
  }

  template <short i, short j>
  METAL_FUNC constexpr const thread frag_type& frag_at() const {
    return val_frags[i * kSubTileCols + j];
  }

  METAL_FUNC thread T* elems() {
    return reinterpret_cast<thread T*>(val_frags);
  }

  METAL_FUNC const thread T* elems() const {
    return reinterpret_cast<const thread T*>(val_frags);
  }

  template <typename Op>
  METAL_FUNC void row_reduce(
      thread metal::vec<T, kRowsPerThread>& vals) const {
    thread T* vptr = (thread T*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::template row_reduce<Op, T>(
            frag_at(i, j), &vptr[i * kFragThrRows]);
      }
    }
  }

  template <typename Op>
  METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
    thread T* vptr = (thread T*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::template row_bin_op<Op, T>(
            frag_at(i, j), &vptr[i * kFragThrRows]);
      }
    }
  }

  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            off_x + i * kFragRows,
            off_y + j * kFragCols);
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::store(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            off_x + i * kFragRows,
            off_y + j * kFragCols);
      }
    }
  }

  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load_rows(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load_rows(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            lim_x,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load_safe(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load_safe(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            lim_x,
            lim_y,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_rows(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        // Forward to the row-guarded fragment store.
        NAXFrag_t::store_rows(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            lim_x,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_safe(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::store_safe(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            lim_x,
            lim_y,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_slice(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) const {
    const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) {
      const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) {
        NAXFrag_t::store_slice(
            frag_at<decltype(idx_row)::value, decltype(idx_col)::value>(),
            dst,
            str_x,
            str_y,
            start_x,
            stop_x,
            start_y,
            stop_y,
            off_x + idx_row * Int<kFragRows>{},
            off_y + idx_col * Int<kFragCols>{});
      });
    });
  }
};
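
// Example (illustrative, not part of the kernels): round-trip a 16 x 16
// float subtile through device memory; `src`, `dst`, and `ld` are
// hypothetical row-major pointers and leading dimension.
//
//   NAXSubTile<float, 16, 16> tile;
//   tile.load(src, ld, Int<1>{});   // read the 16 x 16 block at src
//   tile.store(dst, ld, Int<1>{});  // write it out at dst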

template <
    short RC,
    short CC,
    short RA,
    short CA,
    short RB,
    short CB,
    typename CType,
    typename AType,
    typename BType,
    bool transpose_a,
    bool transpose_b,
    typename NAXFrag_t = BaseNAXFrag>
METAL_FUNC void subtile_matmad_nax(
    thread NAXSubTile<CType, RC, CC, NAXFrag_t>& C,
    thread NAXSubTile<AType, RA, CA, NAXFrag_t>& A,
    metal::bool_constant<transpose_a>,
    thread NAXSubTile<BType, RB, CB, NAXFrag_t>& B,
    metal::bool_constant<transpose_b>) {
  // Static checks
  constexpr short FMa = transpose_a ? CA : RA;
  constexpr short FMc = RC;
  static_assert(FMa == FMc, "NAX matmul: M dimensions do not match");

  constexpr short FNb = transpose_b ? RB : CB;
  constexpr short FNc = CC;
  static_assert(FNb == FNc, "NAX matmul: N dimensions do not match");

  constexpr short FKa = transpose_a ? RA : CA;
  constexpr short FKb = transpose_b ? CB : RB;
  static_assert(FKa == FKb, "NAX matmul: K dimensions do not match");

  constexpr short FM = FMc;
  constexpr short FN = FNc;
  constexpr short FK = FKa;

  constexpr int TM = FM / 16;
  constexpr int TN = FN / 16;
  constexpr int TK = FK / 16;

  constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(
      FM,
      FN,
      FK,
      transpose_a,
      transpose_b,
      true,
      mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);

  // Execution scope argument is a reconstruction; the op runs on a single
  // simdgroup.
  mpp::tensor_ops::matmul2d<desc, mpp::tensor_ops::execution_simdgroups<1>>
      gemm_op;

  auto ct_a = gemm_op.template get_left_input_cooperative_tensor<AType>();
  auto ct_b = gemm_op.template get_right_input_cooperative_tensor<BType>();
  auto ct_c = gemm_op.template get_destination_cooperative_tensor<
      decltype(ct_a),
      decltype(ct_b),
      CType>();

  // Pack A fragments into the left cooperative tensor.
  STEEL_PRAGMA_UNROLL
  for (short mm = 0; mm < TM; mm++) {
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < TK; kk++) {
      const short fi = transpose_a ? kk : mm;
      const short fj = transpose_a ? mm : kk;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < 8; i++) {
        ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i];
      }
    }
  }

  // Pack B fragments into the right cooperative tensor.
  STEEL_PRAGMA_UNROLL
  for (short nn = 0; nn < TN; nn++) {
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < TK; kk++) {
      const short fi = transpose_b ? nn : kk;
      const short fj = transpose_b ? kk : nn;

      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < 8; i++) {
        ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i];
      }
    }
  }

  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < ct_c.get_capacity(); i++) {
    ct_c[i] = C.elems()[i];
  }

  gemm_op.run(ct_a, ct_b, ct_c);

  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < ct_c.get_capacity(); i++) {
    C.elems()[i] = ct_c[i];
  }
}
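
// Example (illustrative): accumulate C += A * B on single 16 x 16 subtiles
// with neither operand transposed; the operand subtiles are assumed to be
// loaded already.
//
//   NAXSubTile<float, 16, 16> C;
//   NAXSubTile<half, 16, 16> A, B;
//   C.clear();
//   subtile_matmad_nax(
//       C, A, metal::bool_constant<false>{}, B, metal::bool_constant<false>{});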

template <
    typename T,
    short kTileRows_,
    short kTileCols_,
    typename NAXSubTile_>
struct NAXTile {
  using NAXSubTile_t = NAXSubTile_;
  using elem_type = T;

  STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows;
  STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols;
  STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile;

  STEEL_CONST short kTileRows = kTileRows_;
  STEEL_CONST short kTileCols = kTileCols_;

  STEEL_CONST short kRows = kTileRows * kSubTileRows;
  STEEL_CONST short kCols = kTileCols * kSubTileCols;

  STEEL_CONST short kSubTiles = kTileRows * kTileCols;
  STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile;

  STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread;
  STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread;

  STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread;
  STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread;

  NAXSubTile_t val_subtiles[kSubTiles];

  METAL_FUNC NAXTile() thread {}

  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTiles; ++i) {
      val_subtiles[i].clear();
    }
  }

  METAL_FUNC constexpr thread NAXSubTile_t& subtile_at(
      const short i,
      const short j) {
    return val_subtiles[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at(
      const short i,
      const short j) const {
    return val_subtiles[i * kTileCols + j];
  }

  template <short i, short j>
  METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const {
    return val_subtiles[i * kTileCols + j];
  }

  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_subtiles[0].elems());
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_subtiles[0].elems());
  }

  template <typename Op>
  METAL_FUNC void row_reduce(
      thread metal::vec<T, kRowsPerThread>& vals) const {
    auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).template row_reduce<Op>(sub_rows[i]);
      }
    }
  }

  template <typename Op>
  METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
    auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).template row_bin_op<Op>(sub_rows[i]);
      }
    }
  }

  template <typename U, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).load(
            src,
            Int<str_x>{},
            Int<str_y>{},
            i * kSubTileRows,
            j * kSubTileCols);
      }
    }
  }

  template <typename U, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).store(
            dst,
            Int<str_x>{},
            Int<str_y>{},
            i * kSubTileRows,
            j * kSubTileCols);
      }
    }
  }

  template <typename U>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).load(
            &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
      }
    }
  }

  template <typename U>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).store(
            &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
      }
    }
  }

  template <typename U>
  METAL_FUNC void load_safe(
      const device U* src,
      const int ld,
      const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).load_safe(
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            i * kSubTileRows,
            j * kSubTileCols);
      }
    }
  }

  template <typename U>
  METAL_FUNC void load_rows(
      const device U* src,
      const int ld,
      const short n_rows) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).load_rows(
            &src[(i * kSubTileRows) * ld + (j * kSubTileCols)],
            ld,
            Int<1>{},
            n_rows - i * kSubTileRows);
      }
    }
  }

  template <typename U>
  METAL_FUNC void store_safe(
      device U* dst,
      const int ld,
      const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).store_safe(
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            i * kSubTileRows,
            j * kSubTileCols);
      }
    }
  }

  template <typename U>
  METAL_FUNC void store_rows(
      device U* dst,
      const int ld,
      const short n_rows) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        subtile_at(i, j).store_rows(
            &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)],
            ld,
            Int<1>{},
            n_rows - i * kSubTileRows);
      }
    }
  }

  template <typename U>
  METAL_FUNC void store_slice(
      device U* dst,
      const int ld,
      const short2 start,
      const short2 stop) const {
    const_for_loop<0, kTileRows, 1>([&](auto idx_row) {
      const_for_loop<0, kTileCols, 1>([&](auto idx_col) {
        subtile_at<decltype(idx_row)::value, decltype(idx_col)::value>()
            .store_slice(
                dst,
                ld,
                Int<1>{},
                start.y,
                stop.y,
                start.x,
                stop.x,
                idx_row * Int<kSubTileRows>{},
                idx_col * Int<kSubTileCols>{});
      });
    });
  }
};

template <
    class CTile,
    class ATile,
    class BTile,
    bool transpose_a,
    bool transpose_b>
METAL_FUNC void tile_matmad_nax(
    thread CTile& C,
    thread ATile& A,
    metal::bool_constant<transpose_a>,
    thread BTile& B,
    metal::bool_constant<transpose_b>) {
  // Static checks
  constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;
  constexpr short TMc = CTile::kTileRows;
  static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match");

  constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows;
  constexpr short FMc = CTile::kSubTileRows;
  static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match");

  constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;
  constexpr short TNc = CTile::kTileCols;
  static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match");

  constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols;
  constexpr short FNc = CTile::kSubTileCols;
  static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match");

  constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;
  constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows;
  static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match");

  constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols;
  constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows;
  static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match");

  constexpr short TM = TMc;
  constexpr short TN = TNc;
  constexpr short TK = TKa;

  // Do matmul here
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < TM; ++i) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < TN; ++j) {
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < TK; ++k) {
        const short ra = transpose_a ? k : i;
        const short ca = transpose_a ? i : k;
        const short rb = transpose_b ? j : k;
        const short cb = transpose_b ? k : j;

        subtile_matmad_nax(
            C.subtile_at(i, j),
            A.subtile_at(ra, ca),
            metal::bool_constant<transpose_a>{},
            B.subtile_at(rb, cb),
            metal::bool_constant<transpose_b>{});
      }
    }
  }
}
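
// Example (illustrative sketch of how a kernel might drive these tiles for a
// 32 x 32 output block; `a`, `b`, `d`, `lda`, `ldb`, `ldd`, and `K` are
// hypothetical, with all operands row-major and non-transposed):
//
//   NAXTile<float, 2, 2, NAXSubTile<float, 16, 16>> C;
//   NAXTile<half, 2, 1, NAXSubTile<half, 16, 16>> A; // 32 x 16 slab of a
//   NAXTile<half, 1, 2, NAXSubTile<half, 16, 16>> B; // 16 x 32 slab of b
//   C.clear();
//   for (int k = 0; k < K; k += 16) {
//     A.load(&a[k], lda);
//     B.load(&b[k * ldb], ldb);
//     tile_matmad_nax(
//         C, A, metal::bool_constant<false>{}, B,
//         metal::bool_constant<false>{});
//   }
//   C.store(d, ldd);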

} // namespace steel
} // namespace mlx