Files
mlx/mlx/backend/metal/kernels/steel/attn/nax.h
2025-11-19 11:15:52 -08:00

1080 lines
29 KiB
C++

// Copyright © 2025 Apple Inc.
#pragma once
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
///////////////////////////////////////////////////////////////////////////////
// NAX Steel with new tiles
///////////////////////////////////////////////////////////////////////////////
// Describes one 16x16 simdgroup matrix fragment for the NAX path, together
// with the per-lane element layout and elementwise load/store/reduce helpers.
// Each of the 32 lanes owns kElemsPerFrag = (16 * 16) / 32 = 8 elements,
// arranged as kElemRows = 2 rows of kElemCols = 4 contiguous columns, the
// two rows sitting kElemRowsJump = 8 fragment-rows apart.
struct BaseNAXFrag {
  STEEL_CONST short kFragRows = 16;
  STEEL_CONST short kFragCols = 16;
  STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32;

  // Per-lane element layout inside the fragment.
  STEEL_CONST short kElemRows = 2;
  STEEL_CONST short kElemCols = 4;
  STEEL_CONST short kElemRowsJump = 8;

  static_assert(
      kElemRows * kElemCols == kElemsPerFrag,
      "MMAFrag shape is not consistent with MMAFrag size");

  template <typename U>
  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;

  template <typename U>
  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;

  // (column, row) coordinate of this lane's first element in the fragment.
  METAL_FUNC static short2 get_coord() {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3));
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4;
    return short2{fn, fm};
  }

  // (column, row) coordinate of this lane's idx-th element (idx in [0, 8)).
  METAL_FUNC static short2 get_coord(short idx) {
    const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort());
    const short qid = simd_lane_id >> 2;
    const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8;
    const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4;
    return short2{fn, fm};
  }

  // Unguarded fragment load from src with strides (str_x, str_y) and offsets
  // (off_x, off_y).
  // NOTE(review): the per-lane coordinate is forced to {0, 0} instead of
  // get_coord() throughout this struct; presumably lane distribution is
  // handled by the cooperative-tensor staging in subtile_matmad_nax —
  // confirm before re-enabling get_coord().
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      // Fast path: unit column stride avoids the multiply per element.
      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + c + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        }
      }
    }
  }

  // Row-guarded load: rows with index >= lim_x are filled with zeros,
  // columns are assumed fully in bounds.
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_rows(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] = static_cast<T>(src[r * str_x + (c + j)]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[i * kElemCols + j] =
                static_cast<T>(src[r * str_x + (c + j) * str_y]);
          }
        }
      } else {
        // Zero only this row's elements. Zeroing the whole fragment here
        // (dst = dtype_frag_t<T>(0)) would discard rows already loaded by
        // earlier iterations when only the tail rows are out of bounds.
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Fully-guarded load: any element outside (lim_x, lim_y) reads as zero.
  template <
      typename T,
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void load_safe(
      thread dtype_frag_t<T>& dst,
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[i * kElemCols + j] =
              static_cast<T>(src[r * str_x + (c + j) * str_y]);
        } else {
          dst[i * kElemCols + j] = T(0);
        }
      }
    }
  }

  // Unguarded fragment store, mirroring load(). Values are cast to the
  // destination pointer's element type.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      if constexpr (metal::is_same_v<StrY, Int<1>>) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < kElemCols; j++) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Row-guarded store: rows with index >= lim_x are skipped.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_rows(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      if (r < lim_x) {
        if constexpr (metal::is_same_v<StrY, Int<1>>) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + c + j] = static_cast<U>(src[i * kElemCols + j]);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < kElemCols; j++) {
            dst[r * str_x + (c + j) * str_y] =
                static_cast<U>(src[i * kElemCols + j]);
          }
        }
      }
    }
  }

  // Fully-guarded store: elements outside (lim_x, lim_y) are skipped.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_safe(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    using U = pointer_element_t<DstPtrType>;
    const short2 sc = short2{0, 0}; // get_coord();
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      const auto r = off_x + i * kElemRowsJump + sc.y;
      const auto c = off_y + sc.x;
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        if (r < lim_x && (c + j) < lim_y) {
          dst[r * str_x + (c + j) * str_y] =
              static_cast<U>(src[i * kElemCols + j]);
        }
      }
    }
  }

  // Windowed store: only elements with row in [start_x, stop_x) and column
  // in [start_y, stop_y) are written. Uses compile-time loops so the bounds
  // tests fold away when the limits are integral constants.
  template <
      typename T,
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC static constexpr void store_slice(
      const thread dtype_frag_t<T>& src,
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) {
    using U = pointer_element_t<DstPtrType>;
    const short2 sc = short2{0, 0}; // get_coord();
    const_for_loop<0, kElemRows, 1>([&](auto idx_row) {
      const auto r = off_x + idx_row * Int<kElemRowsJump>{};
      if (r >= stop_x - sc.y || r < start_x - sc.y) {
        return;
      }
      const_for_loop<0, kElemCols, 1>([&](auto idx_col) {
        const auto c = off_y + idx_col;
        if (c >= stop_y - sc.x || c < start_y - sc.x) {
          return;
        }
        const auto src_idx = idx_row * Int<kElemCols>{} + idx_col;
        dst[(r + sc.y) * str_x + (c + sc.x) * str_y] =
            static_cast<U>(src[src_idx]);
      });
    });
  }

  // Per-row reduction with Op: first fold the 4 elements this lane holds in
  // each of its 2 rows, then combine with lanes xor-1 and xor-8 (the lanes
  // sharing a fragment row under the layout encoded in get_coord()), and
  // accumulate into reduced_vals[i] via Op.
  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_reduce(
      thread const dtype_frag_t<T>& inp_vals,
      thread T* reduced_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      T thr_reduce = Op::apply(
          Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]),
          Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3]));
      T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1));
      qgr_reduce = Op::apply(thr_reduce, qgr_reduce);
      T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8));
      sgr_reduce = Op::apply(qgr_reduce, sgr_reduce);
      reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce);
    }
  }

  // Applies Op(element, row_vals[row]) to every element in place, i.e. a
  // broadcast binary op of a per-row value across each row.
  template <typename Op, typename T>
  METAL_FUNC static constexpr void row_bin_op(
      thread dtype_frag_t<T>& inp_vals,
      thread T* row_vals) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kElemRows; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kElemCols; j++) {
        inp_vals[i * kElemCols + j] =
            Op::apply(inp_vals[i * kElemCols + j], row_vals[i]);
      }
    }
  }
};
// A register-resident subtile of kRows_ x kCols_ elements, stored as a grid
// of kSubTileRows x kSubTileCols NAXFrag_ fragments. All loads/stores simply
// fan out to the per-fragment helpers with the fragment's (row, col) offset
// added to the caller's offsets.
template <
    typename T,
    short kRows_,
    short kCols_,
    typename NAXFrag_ = BaseNAXFrag>
struct NAXSubTile {
  using NAXFrag_t = NAXFrag_;
  STEEL_CONST short kRows = kRows_;
  STEEL_CONST short kCols = kCols_;

  STEEL_CONST short kFragRows = NAXFrag_t::kFragRows;
  STEEL_CONST short kFragCols = NAXFrag_t::kFragCols;
  STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag;

  // Fragment grid and per-thread element counts derived from the frag type.
  STEEL_CONST short kSubTileRows = kRows / kFragRows;
  STEEL_CONST short kSubTileCols = kCols / kFragCols;
  STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols;
  STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag;

  STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows;
  STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols;

  STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows;
  STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols;
  STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump;

  using frag_type = typename NAXFrag_t::template dtype_frag_t<T>;

  // Row-major array of fragments.
  frag_type val_frags[kNumFrags];

  // Zero-fill every fragment.
  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  // Fragment accessors (runtime and compile-time indexed).
  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kSubTileCols + j];
  }
  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kSubTileCols + j];
  }
  template <int i, int j>
  METAL_FUNC constexpr thread frag_type& frag_at() {
    return val_frags[i * kSubTileCols + j];
  }
  template <int i, int j>
  METAL_FUNC constexpr const thread frag_type& frag_at() const {
    return val_frags[i * kSubTileCols + j];
  }

  // Raw element views over the fragment storage.
  METAL_FUNC thread T* elems() {
    return reinterpret_cast<thread T*>(val_frags);
  }
  METAL_FUNC const thread T* elems() const {
    return reinterpret_cast<const thread T*>(val_frags);
  }

  // Reduce each thread-row across all fragments into vals (one slot per
  // thread-row), accumulating with Op.
  template <typename Op>
  METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
    thread T* vptr = (thread T*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::template row_reduce<Op>(
            frag_at(i, j), &vptr[i * kFragThrRows]);
      }
    }
  }

  // Broadcast-apply Op(element, vals[row]) across every fragment in place.
  template <typename Op>
  METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
    thread T* vptr = (thread T*)(&vals);
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::template row_bin_op<Op>(
            frag_at(i, j), &vptr[i * kFragThrRows]);
      }
    }
  }

  // Unguarded subtile load.
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            off_x + i * kFragRows,
            off_y + j * kFragCols);
      }
    }
  }

  // Unguarded subtile store.
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::store(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            off_x + i * kFragRows,
            off_y + j * kFragCols);
      }
    }
  }

  // Row-guarded subtile load (rows >= lim_x read as zero).
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load_rows(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load_rows(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            lim_x,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  // Fully-guarded subtile load (elements outside (lim_x, lim_y) read as 0).
  template <
      typename SrcPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void load_safe(
      SrcPtrType src,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::load_safe(
            frag_at(i, j),
            src,
            str_x,
            str_y,
            lim_x,
            lim_y,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  // Row-guarded subtile store (rows >= lim_x skipped).
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_rows(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        // Delegate to the fragment-level row-guarded store. Calling
        // store_safe here (as before) shifted every argument: the row
        // offset was consumed as the column limit and the column offset as
        // the row offset, producing wrong bounds and wrong destinations.
        NAXFrag_t::store_rows(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            lim_x,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  // Fully-guarded subtile store (elements outside (lim_x, lim_y) skipped).
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename LimX,
      typename LimY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_safe(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      LimX lim_x,
      LimY lim_y,
      OffX off_x = {},
      OffY off_y = {}) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kSubTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kSubTileCols; ++j) {
        NAXFrag_t::store_safe(
            frag_at(i, j),
            dst,
            str_x,
            str_y,
            lim_x,
            lim_y,
            off_x + (i * kFragRows),
            off_y + (j * kFragCols));
      }
    }
  }

  // Windowed subtile store: only rows in [start_x, stop_x) and columns in
  // [start_y, stop_y) are written. Compile-time loops keep the fragment
  // indices constexpr for frag_at<i, j>().
  template <
      typename DstPtrType,
      typename StrX,
      typename StrY,
      typename StartX,
      typename StopX,
      typename StartY,
      typename StopY,
      typename OffX = Int<0>,
      typename OffY = Int<0>>
  METAL_FUNC constexpr void store_slice(
      DstPtrType dst,
      StrX str_x,
      StrY str_y,
      StartX start_x,
      StopX stop_x,
      StartY start_y,
      StopY stop_y,
      OffX off_x = Int<0>{},
      OffY off_y = Int<0>{}) const {
    const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) {
      const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) {
        NAXFrag_t::store_slice(
            frag_at<idx_row.value, idx_col.value>(),
            dst,
            str_x,
            str_y,
            start_x,
            stop_x,
            start_y,
            stop_y,
            off_x + idx_row * Int<kFragRows>{},
            off_y + idx_col * Int<kFragCols>{});
      });
    });
  }
};
// Computes C += A @ B (honoring the transpose flags) for one register
// subtile, using the MetalPerformancePrimitives matmul2d tensor op.
// Fragment data is staged into the op's cooperative tensors, the op runs in
// multiply_accumulate mode, and the destination tensor is copied back to C.
template <
    short RC,
    short CC,
    short RA,
    short CA,
    short RB,
    short CB,
    typename CType,
    typename AType,
    typename BType,
    bool transpose_a,
    bool transpose_b,
    typename NAXFrag_t = BaseNAXFrag>
METAL_FUNC void subtile_matmad_nax(
    thread NAXSubTile<CType, RC, CC, NAXFrag_t>& C,
    thread NAXSubTile<AType, RA, CA, NAXFrag_t>& A,
    metal::bool_constant<transpose_a>,
    thread NAXSubTile<BType, RB, CB, NAXFrag_t>& B,
    metal::bool_constant<transpose_b>) {
  // Static checks: logical M/N/K of each operand must agree.
  constexpr short FMa = transpose_a ? CA : RA;
  constexpr short FMc = RC;
  static_assert(FMa == FMc, "NAX matmul: M dimensions do not match");

  constexpr short FNb = transpose_b ? RB : CB;
  constexpr short FNc = CC;
  static_assert(FNb == FNc, "NAX matmul: N dimensions do not match");

  constexpr short FKa = transpose_a ? RA : CA;
  constexpr short FKb = transpose_b ? CB : RB;
  static_assert(FKa == FKb, "NAX matmul: K dimensions do not match");

  constexpr short FM = FMc;
  constexpr short FN = FNc;
  constexpr short FK = FKa;

  // Number of 16x16 fragments along each dimension.
  constexpr int TM = FM / 16;
  constexpr int TN = FN / 16;
  constexpr int TK = FK / 16;

  constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(
      FM,
      FN,
      FK,
      transpose_a,
      transpose_b,
      true,
      mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);

  mpp::tensor_ops::matmul2d<desc, metal::execution_simdgroup> gemm_op;

  auto ct_a =
      gemm_op.template get_left_input_cooperative_tensor<AType, BType, CType>();
  auto ct_b =
      gemm_op
          .template get_right_input_cooperative_tensor<AType, BType, CType>();
  auto ct_c = gemm_op.template get_destination_cooperative_tensor<
      decltype(ct_a),
      decltype(ct_b),
      CType>();

  // Stage A fragments. The fragment grid is indexed (row, col) in storage
  // order, so a transposed A swaps which of (mm, kk) selects the row.
  // 8 = elements each lane holds per 16x16 fragment.
  STEEL_PRAGMA_UNROLL
  for (short mm = 0; mm < TM; mm++) {
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < TK; kk++) {
      const short fi = transpose_a ? kk : mm;
      const short fj = transpose_a ? mm : kk;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < 8; i++) {
        ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i];
      }
    }
  }

  // Stage B fragments, mirroring the A staging for the right operand.
  STEEL_PRAGMA_UNROLL
  for (short nn = 0; nn < TN; nn++) {
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < TK; kk++) {
      const short fi = transpose_b ? nn : kk;
      const short fj = transpose_b ? kk : nn;
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < 8; i++) {
        ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i];
      }
    }
  }

  // Seed the accumulator with C's current contents (multiply_accumulate).
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < ct_c.get_capacity(); i++) {
    ct_c[i] = C.elems()[i];
  }

  gemm_op.run(ct_a, ct_b, ct_c);

  // Copy the accumulated result back into C's fragment storage.
  STEEL_PRAGMA_UNROLL
  for (short i = 0; i < ct_c.get_capacity(); i++) {
    C.elems()[i] = ct_c[i];
  }
}
// A register tile built as a kTileRows_ x kTileCols_ grid of NAXSubTile_
// subtiles. Provides bulk load/store between registers and threadgroup or
// device memory, plus per-row reduce / broadcast helpers that fan out to
// the subtile implementations.
template <typename T, short kTileRows_, short kTileCols_, class NAXSubTile_>
struct NAXTile {
using NAXSubTile_t = NAXSubTile_;
using elem_type = T;
// Element extent of one subtile.
STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows;
STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols;
STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile;
// Subtile-grid extent of the tile.
STEEL_CONST short kTileRows = kTileRows_;
STEEL_CONST short kTileCols = kTileCols_;
// Total element extent of the tile.
STEEL_CONST short kRows = kTileRows * kSubTileRows;
STEEL_CONST short kCols = kTileCols * kSubTileCols;
STEEL_CONST short kSubTiles = kTileRows * kTileCols;
STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile;
// Per-thread row/column element counts across the whole tile.
STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread;
STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread;
STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread;
STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread;
// Row-major array of subtiles.
NAXSubTile_t val_subtiles[kSubTiles];
METAL_FUNC NAXTile() thread {}
// Zero-fill every subtile.
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kSubTiles; ++i) {
val_subtiles[i].clear();
}
}
// Subtile accessors (runtime and compile-time indexed).
METAL_FUNC constexpr thread NAXSubTile_t& subtile_at(
const short i,
const short j) {
return val_subtiles[i * kTileCols + j];
}
METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at(
const short i,
const short j) const {
return val_subtiles[i * kTileCols + j];
}
template <int i, int j>
METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const {
return val_subtiles[i * kTileCols + j];
}
// Raw element views over the subtile storage (relies on the subtiles
// being laid out contiguously in val_subtiles).
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_subtiles[0].elems());
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_subtiles[0].elems());
}
// Reduce each thread-row across the tile into vals, accumulating with Op.
// vals is viewed as kTileRows chunks of kSubTileThrRows entries.
template <typename Op>
METAL_FUNC void row_reduce(thread metal::vec<T, kRowsPerThread>& vals) const {
auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).template row_reduce<Op>(sub_rows[i]);
}
}
}
// Broadcast-apply Op(element, vals[row]) across the whole tile in place.
template <typename Op>
METAL_FUNC void row_bin_op(thread metal::vec<T, kRowsPerThread>& vals) {
auto sub_rows = (thread metal::vec<T, kSubTileThrRows>*)(&vals);
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).template row_bin_op<Op>(sub_rows[i]);
}
}
}
// Unguarded load from threadgroup memory with compile-time strides.
template <typename U, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).load(
src,
Int<str_x>{},
Int<str_y>{},
i * kSubTileRows,
j * kSubTileCols);
}
}
}
// Unguarded store to threadgroup memory with compile-time strides.
template <typename U, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).store(
dst,
Int<str_x>{},
Int<str_y>{},
i * kSubTileRows,
j * kSubTileCols);
}
}
}
// Unguarded load from device memory, row-major with leading dimension ld.
template <typename U>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).load(
&src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
}
}
}
// Unguarded store to device memory, row-major with leading dimension ld.
template <typename U>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
subtile_at(i, j).store(
&dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{});
}
}
}
// Bounds-checked load: src_tile_dims is (cols, rows) as (x, y); elements
// outside it read as zero.
template <typename U>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
subtile_at(i, j).load_safe(
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
i * kSubTileRows,
j * kSubTileCols);
}
}
}
// Row-guarded load: only the first n_rows rows are read; the rest are
// zero-filled. The per-subtile limit is rebased by the subtile row offset.
template <typename U>
METAL_FUNC void
load_rows(const device U* src, const int ld, const short n_rows) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
subtile_at(i, j).load_rows(
&src[(i * kSubTileRows) * ld + (j * kSubTileCols)],
ld,
Int<1>{},
n_rows - i * kSubTileRows);
}
}
}
// Bounds-checked store: dst_tile_dims is (cols, rows) as (x, y); elements
// outside it are skipped.
template <typename U>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
subtile_at(i, j).store_safe(
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
i * kSubTileRows,
j * kSubTileCols);
}
}
}
// Row-guarded store: only the first n_rows rows are written. Delegates to
// the subtile row-guarded store with a rebased row limit.
template <typename U>
METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows)
const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
subtile_at(i, j).store_rows(
&dst[(i * kSubTileRows) * ld + (j * kSubTileCols)],
ld,
Int<1>{},
n_rows - i * kSubTileRows);
}
}
}
// Windowed store: writes only elements whose (row, col) lies in
// [start.y, stop.y) x [start.x, stop.x). Compile-time loops keep the
// subtile indices constexpr for subtile_at<i, j>().
template <typename U>
METAL_FUNC void store_slice(
device U* dst,
const int ld,
const short2 start,
const short2 stop) const {
const_for_loop<0, kTileRows, 1>([&](auto idx_row) {
const_for_loop<0, kTileCols, 1>([&](auto idx_col) {
subtile_at<idx_row.value, idx_col.value>().store_slice(
dst,
ld,
Int<1>{},
start.y,
stop.y,
start.x,
stop.x,
idx_row * Int<kSubTileRows>{},
idx_col * Int<kSubTileCols>{});
});
});
}
};
// Tile-level C += A @ B (with optional operand transposes): iterates the
// subtile grids of the three tiles and issues one subtile_matmad_nax per
// (m, n, k) term, after statically verifying that the tile and subtile
// shapes of A, B and C are mutually consistent.
template <
    class CTile,
    class ATile,
    class BTile,
    bool transpose_a,
    bool transpose_b>
METAL_FUNC void tile_matmad_nax(
    thread CTile& C,
    thread ATile& A,
    metal::bool_constant<transpose_a>,
    thread BTile& B,
    metal::bool_constant<transpose_b>) {
  // Resolve logical extents per operand (transpose-aware) and check them.
  constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows;
  constexpr short TMc = CTile::kTileRows;
  static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match");
  constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows;
  constexpr short FMc = CTile::kSubTileRows;
  static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match");

  constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols;
  constexpr short TNc = CTile::kTileCols;
  static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match");
  constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols;
  constexpr short FNc = CTile::kSubTileCols;
  static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match");

  constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols;
  constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows;
  static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match");
  constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols;
  constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows;
  static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match");

  // Subtile-grid extents of the product.
  constexpr short kTilesM = TMc;
  constexpr short kTilesN = TNc;
  constexpr short kTilesK = TKa;

  // C[mi, ni] += sum over ki of A[mi, ki] * B[ki, ni], with operand
  // indices swapped to storage order when that operand is transposed.
  STEEL_PRAGMA_UNROLL
  for (short mi = 0; mi < kTilesM; ++mi) {
    STEEL_PRAGMA_UNROLL
    for (short ni = 0; ni < kTilesN; ++ni) {
      STEEL_PRAGMA_UNROLL
      for (short ki = 0; ki < kTilesK; ++ki) {
        const short a_row = transpose_a ? ki : mi;
        const short a_col = transpose_a ? mi : ki;
        const short b_row = transpose_b ? ni : ki;
        const short b_col = transpose_b ? ki : ni;
        subtile_matmad_nax(
            C.subtile_at(mi, ni),
            A.subtile_at(a_row, a_col),
            metal::bool_constant<transpose_a>{},
            B.subtile_at(b_row, b_col),
            metal::bool_constant<transpose_b>{});
      }
    }
  }
}
} // namespace steel
} // namespace mlx