mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[WIP] Init NAX attention
This commit is contained in:
@@ -137,6 +137,11 @@ if(MLX_ENABLE_NAX)
|
|||||||
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS})
|
||||||
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
|
build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS})
|
||||||
|
|
||||||
|
set(STEEL_NAX_ATTN_HEADERS
|
||||||
|
steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
|
||||||
|
steel/utils/integral_constant.h)
|
||||||
|
|
||||||
|
build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
|||||||
@@ -0,0 +1,475 @@
|
|||||||
|
// Copyright © 2024-25 Apple Inc.
|
||||||
|
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// GEMM kernels
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
constant bool align_Q [[function_constant(200)]];
|
||||||
|
constant bool align_K [[function_constant(201)]];
|
||||||
|
|
||||||
|
constant bool has_mask [[function_constant(300)]];
|
||||||
|
constant bool do_causal [[function_constant(301)]];
|
||||||
|
constant bool has_sinks [[function_constant(302)]];
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct TransformScale {
|
||||||
|
T scale;
|
||||||
|
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
|
||||||
|
|
||||||
|
METAL_FUNC T apply(T x) const {
|
||||||
|
return scale * x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MaxOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return metal::max(x, y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SumOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x + y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MulOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x * y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SubOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x - y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ExpSubOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return fast::exp2(x - y);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DivOp {
|
||||||
|
template <typename T>
|
||||||
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
|
return x / y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
int BQ,
|
||||||
|
int BK,
|
||||||
|
int BD,
|
||||||
|
int WM,
|
||||||
|
int WN,
|
||||||
|
typename MaskType = float,
|
||||||
|
typename AccumType = float>
|
||||||
|
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax(
|
||||||
|
const device T* Q [[buffer(0)]],
|
||||||
|
const device T* K [[buffer(1)]],
|
||||||
|
const device T* V [[buffer(2)]],
|
||||||
|
device T* O [[buffer(3)]],
|
||||||
|
const constant AttnParams* params [[buffer(4)]],
|
||||||
|
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
|
||||||
|
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
|
||||||
|
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
|
||||||
|
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||||
|
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
|
||||||
|
|
||||||
|
// Pacifying compiler
|
||||||
|
(void)lid;
|
||||||
|
|
||||||
|
// Move to correct block
|
||||||
|
ulong3 tidl{tid.x, tid.y, tid.z};
|
||||||
|
|
||||||
|
Q += tidl.z * params->Q_strides[0] + // Batch
|
||||||
|
tidl.y * params->Q_strides[1] + // Head
|
||||||
|
tidl.x * BQ * params->Q_strides[2]; // Sequence
|
||||||
|
|
||||||
|
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
|
||||||
|
K += tidl.z * params->K_strides[0] + // Batch
|
||||||
|
kv_head_idx * params->K_strides[1]; // Head
|
||||||
|
|
||||||
|
V += tidl.z * params->V_strides[0] + // Batch
|
||||||
|
kv_head_idx * params->V_strides[1]; // Head
|
||||||
|
|
||||||
|
O += tidl.z * params->O_strides[0] + // Batch
|
||||||
|
tidl.y * params->O_strides[1] + // Head
|
||||||
|
tidl.x * BQ * params->O_strides[2]; // Sequence
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
mask += tidl.z * mask_params->M_strides[0] + // Batch
|
||||||
|
tidl.y * mask_params->M_strides[1]; // Head
|
||||||
|
}
|
||||||
|
|
||||||
|
const metal::uniform<float> scale2 =
|
||||||
|
make_uniform(params->scale) * make_uniform(1.44269504089f);
|
||||||
|
|
||||||
|
// Prepare MMA tiles
|
||||||
|
constexpr short UQ = 16;
|
||||||
|
constexpr short UD = 32;
|
||||||
|
|
||||||
|
constexpr int kNWarps = WM * WN;
|
||||||
|
static_assert(
|
||||||
|
BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0,
|
||||||
|
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
|
||||||
|
|
||||||
|
// Q seq frags per warp
|
||||||
|
constexpr int TQ = BQ / (kNWarps * UQ);
|
||||||
|
// HeadDim frags (all warps load the same frags)
|
||||||
|
constexpr int TD = BD / UD;
|
||||||
|
|
||||||
|
static_assert(TQ == 1, "Check TQ");
|
||||||
|
|
||||||
|
using OSubTile = NAXSubTile<AccumType, UQ, UD>;
|
||||||
|
NAXTile<AccumType, TQ, TD, OSubTile> Otile;
|
||||||
|
|
||||||
|
Otile.clear();
|
||||||
|
|
||||||
|
// Prepare mma tile offsets
|
||||||
|
const short2 simd_coord = OSubTile::NAXFrag_t::get_coord();
|
||||||
|
const short sm = simd_coord.y;
|
||||||
|
const short sn = simd_coord.x;
|
||||||
|
const short tm = UQ * TQ * simd_group_id;
|
||||||
|
|
||||||
|
Q += (tm + sm) * int(params->Q_strides[2]) + sn;
|
||||||
|
K += sm * int(params->K_strides[2]) + sn;
|
||||||
|
V += sm * int(params->V_strides[2]) + sn;
|
||||||
|
|
||||||
|
// Init row reduction variables
|
||||||
|
constexpr short kRowsPT = decltype(Otile)::kRowsPerThread;
|
||||||
|
|
||||||
|
metal::vec<AccumType, kRowsPT> max_score;
|
||||||
|
metal::vec<AccumType, kRowsPT> sum_score{0};
|
||||||
|
|
||||||
|
// Init to -Inf
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
max_score[i] = Limits<AccumType>::finite_min;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_sinks) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
max_score[i] = M_LOG2E_F * static_cast<AccumType>(sinks[tidl.y]);
|
||||||
|
sum_score[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int kb_lim = params->NK;
|
||||||
|
|
||||||
|
if (do_causal) {
|
||||||
|
int q_max = (tid.x + 1) * BQ + params->qL_off;
|
||||||
|
kb_lim = (q_max + BK - 1) / BK;
|
||||||
|
kb_lim = min(params->NK, kb_lim);
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
|
||||||
|
const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
|
||||||
|
const bool is_last_q = is_last_bq;
|
||||||
|
|
||||||
|
const short lim_rows_q = params->qL_rem - (tm + sm);
|
||||||
|
const short lim_rows_k = params->kL_rem - sm;
|
||||||
|
|
||||||
|
// Loop over KV seq length
|
||||||
|
for (int kb = 0; kb < kb_lim; kb++) {
|
||||||
|
const int is_last_k = (kb == (params->NK_aligned));
|
||||||
|
|
||||||
|
// Do S = Q @ K.T
|
||||||
|
constexpr short UDs = 16;
|
||||||
|
constexpr short UKs = 32;
|
||||||
|
|
||||||
|
constexpr short TDs = BD / UDs;
|
||||||
|
constexpr short TKs = BK / UKs;
|
||||||
|
|
||||||
|
using SSubTile = NAXSubTile<AccumType, UQ, UKs>;
|
||||||
|
using QSubTile = NAXSubTile<T, UQ, UDs>;
|
||||||
|
using KSubTile = NAXSubTile<T, UKs, UDs>;
|
||||||
|
|
||||||
|
NAXTile<AccumType, TQ, TKs, SSubTile> Stile;
|
||||||
|
|
||||||
|
Stile.clear();
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TKs; ik++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short id = 0; id < TDs; id++) {
|
||||||
|
NAXTile<T, 1, 1, QSubTile> Qtile;
|
||||||
|
NAXTile<T, 1, 1, KSubTile> Ktile;
|
||||||
|
|
||||||
|
const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs;
|
||||||
|
const int K_load_off =
|
||||||
|
ik * UKs * int(params->K_strides[2]) + id * UDs;
|
||||||
|
|
||||||
|
if (!align_Q && is_last_q) {
|
||||||
|
// Qtile.load_rows(
|
||||||
|
// Q + Q_load_off,
|
||||||
|
// int(params->Q_strides[2]),
|
||||||
|
// lim_rows_q - iq * UQ);
|
||||||
|
Qtile.load_safe(
|
||||||
|
Q + Q_load_off,
|
||||||
|
int(params->Q_strides[2]),
|
||||||
|
short2(BD, lim_rows_q - iq * UQ));
|
||||||
|
} else {
|
||||||
|
Qtile.load(Q + Q_load_off, int(params->Q_strides[2]));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!align_K && is_last_k) {
|
||||||
|
// Ktile.load_rows(
|
||||||
|
// K + K_load_off,
|
||||||
|
// int(params->K_strides[2]),
|
||||||
|
// lim_rows_k - ik * UKs);
|
||||||
|
Ktile.load_safe(
|
||||||
|
K + K_load_off,
|
||||||
|
int(params->K_strides[2]),
|
||||||
|
short2(BD, lim_rows_k - ik * UKs));
|
||||||
|
} else {
|
||||||
|
Ktile.load(K + K_load_off, int(params->K_strides[2]));
|
||||||
|
}
|
||||||
|
|
||||||
|
subtile_matmad_nax(
|
||||||
|
Stile.subtile_at(iq, ik),
|
||||||
|
Qtile.subtile_at(0, 0),
|
||||||
|
metal::false_type{},
|
||||||
|
Ktile.subtile_at(0, 0),
|
||||||
|
metal::true_type{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale S
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
|
||||||
|
Stile.elems()[ii] *= float(scale2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale and Retile S
|
||||||
|
constexpr short UK = 16;
|
||||||
|
constexpr short TK = BK / UK;
|
||||||
|
using PSubTile = NAXSubTile<AccumType, UQ, UK>;
|
||||||
|
|
||||||
|
NAXTile<AccumType, TQ, TK, PSubTile> Ptile;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) {
|
||||||
|
Ptile.elems()[ii] = Stile.elems()[ii];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask out length sequence
|
||||||
|
if (!align_K && is_last_k) {
|
||||||
|
constexpr auto neg_inf = Limits<AccumType>::finite_min;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TK; ik++) {
|
||||||
|
const short col_pos = sn + ik * UK;
|
||||||
|
|
||||||
|
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
|
||||||
|
const auto loc = ii * PSubTile::kFragThrCols + jj;
|
||||||
|
fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask out if causal
|
||||||
|
if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) {
|
||||||
|
constexpr auto neg_inf = Limits<AccumType>::finite_min;
|
||||||
|
|
||||||
|
const int base_row = tid.x * BQ + params->qL_off + tm;
|
||||||
|
const int base_col = kb * BK;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TK; ik++) {
|
||||||
|
const short row_pos = base_row + iq * UQ;
|
||||||
|
const short col_pos = base_col + ik * UK;
|
||||||
|
|
||||||
|
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) {
|
||||||
|
const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm;
|
||||||
|
const auto c = col_pos + jj + sn;
|
||||||
|
const auto loc = ii * PSubTile::kFragThrCols + jj;
|
||||||
|
fg[loc] = (r < c) ? neg_inf : fg[loc];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other masking as needed
|
||||||
|
if (has_mask) {
|
||||||
|
constexpr auto neg_inf = Limits<AccumType>::finite_min;
|
||||||
|
|
||||||
|
const int base_row = tid.x * BQ + tm;
|
||||||
|
const int base_col = kb * BK;
|
||||||
|
|
||||||
|
constexpr bool is_bool = is_same_v<MaskType, bool>;
|
||||||
|
using melem_t = typename metal::conditional_t<is_bool, bool, AccumType>;
|
||||||
|
using MSubTile = NAXSubTile<melem_t, UQ, UK>;
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TK; ik++) {
|
||||||
|
const short row_pos = base_row + iq * UQ + sm;
|
||||||
|
const short col_pos = base_col + ik * UK + sn;
|
||||||
|
|
||||||
|
MSubTile mfrag;
|
||||||
|
mfrag.load_safe(
|
||||||
|
mask,
|
||||||
|
int(mask_params->M_strides[2]),
|
||||||
|
Int<1>{},
|
||||||
|
params->qL,
|
||||||
|
params->kL,
|
||||||
|
row_pos,
|
||||||
|
col_pos);
|
||||||
|
|
||||||
|
thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) {
|
||||||
|
if constexpr (is_bool) {
|
||||||
|
fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf;
|
||||||
|
} else {
|
||||||
|
fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do softmax
|
||||||
|
|
||||||
|
// Temp variables
|
||||||
|
metal::vec<AccumType, kRowsPT> new_max;
|
||||||
|
metal::vec<AccumType, kRowsPT> factor;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
new_max[i] = max_score[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row max
|
||||||
|
Ptile.template row_reduce<MaxOp>(new_max);
|
||||||
|
|
||||||
|
// exp(Si - rowmax(Si))
|
||||||
|
Ptile.template row_bin_op<ExpSubOp>(new_max);
|
||||||
|
|
||||||
|
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
factor[i] = fast::exp2(max_score[i] - new_max[i]);
|
||||||
|
max_score[i] = new_max[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row Sum
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
sum_score[i] = sum_score[i] * factor[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Ptile.template row_reduce<SumOp>(sum_score);
|
||||||
|
|
||||||
|
// Update O
|
||||||
|
Otile.template row_bin_op<MulOp>(factor);
|
||||||
|
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
// Do O = P @ V
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short id = 0; id < TD; id++) {
|
||||||
|
if constexpr (BD == 128) {
|
||||||
|
if (id == 2) {
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TK; ik++) {
|
||||||
|
using VSubTile = NAXSubTile<T, UK, UD>;
|
||||||
|
NAXTile<T, 1, 1, VSubTile> Vtile;
|
||||||
|
|
||||||
|
const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD;
|
||||||
|
|
||||||
|
if (!align_K && is_last_k) {
|
||||||
|
// Vtile.load_rows(
|
||||||
|
// V + V_load_off,
|
||||||
|
// int(params->V_strides[2]),
|
||||||
|
// lim_rows_k - ik * UK);
|
||||||
|
Vtile.load_safe(
|
||||||
|
V + V_load_off,
|
||||||
|
int(params->V_strides[2]),
|
||||||
|
short2(BD, lim_rows_k - ik * UK));
|
||||||
|
} else {
|
||||||
|
Vtile.load(V + V_load_off, int(params->V_strides[2]));
|
||||||
|
}
|
||||||
|
|
||||||
|
subtile_matmad_nax(
|
||||||
|
Otile.subtile_at(iq, id),
|
||||||
|
Ptile.subtile_at(iq, ik),
|
||||||
|
metal::bool_constant<false>{},
|
||||||
|
Vtile.subtile_at(0, 0),
|
||||||
|
metal::bool_constant<false>{});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare for next iteration
|
||||||
|
K += BK * int(params->K_strides[2]);
|
||||||
|
V += BK * int(params->V_strides[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize output
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
|
metal::vec<AccumType, kRowsPT> rcp;
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
|
rcp[i] = (1.f / sum_score[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Otile.template row_bin_op<MulOp>(rcp);
|
||||||
|
|
||||||
|
// Store results
|
||||||
|
O += (tm + sm) * int(params->O_strides[2]) + sn;
|
||||||
|
|
||||||
|
if (!align_Q && is_last_q) {
|
||||||
|
if (lim_rows_q <= 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Otile.store_rows(O, params->O_strides[2], lim_rows_q);
|
||||||
|
Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q));
|
||||||
|
} else {
|
||||||
|
Otile.store(O, int(params->O_strides[2]));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
// Copyright © 2024-25 Apple Inc.
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#include "mlx/backend/metal/kernels/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/attn/nax.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/attn/params.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
|
||||||
|
#include "mlx/backend/metal/kernels/steel/utils.h"
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h"
|
||||||
|
|
||||||
|
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
|
||||||
|
instantiate_kernel( \
|
||||||
|
"steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||||
|
"_wm" #wm "_wn" #wn "_mask" #mname, \
|
||||||
|
attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float)
|
||||||
|
|
||||||
|
#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
|
||||||
|
instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \
|
||||||
|
instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \
|
||||||
|
instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \
|
||||||
|
instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype)
|
||||||
|
|
||||||
|
#define instantiate_attn_mask_helper(iname, itype) \
|
||||||
|
instantiate_attn_shapes_helper(iname, itype, iname, itype) \
|
||||||
|
instantiate_attn_shapes_helper(iname, itype, bool_, bool)
|
||||||
|
|
||||||
|
instantiate_attn_mask_helper(float16, half);
|
||||||
|
instantiate_attn_mask_helper(bfloat16, bfloat);
|
||||||
|
|
||||||
|
instantiate_attn_mask_helper(float32, float);
|
||||||
|
// clang-format on
|
||||||
1079
mlx/backend/metal/kernels/steel/attn/nax.h
Normal file
1079
mlx/backend/metal/kernels/steel/attn/nax.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,6 +12,146 @@
|
|||||||
namespace mlx::core::fast {
|
namespace mlx::core::fast {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
#ifdef MLX_ENABLE_NAX
|
||||||
|
|
||||||
|
void sdpa_full_self_attention_nax(
|
||||||
|
const Stream& s,
|
||||||
|
metal::Device& d,
|
||||||
|
const array& q,
|
||||||
|
const array& k,
|
||||||
|
const array& v,
|
||||||
|
const float scale,
|
||||||
|
array& o,
|
||||||
|
bool do_causal_,
|
||||||
|
const std::optional<array>& mask,
|
||||||
|
const std::optional<array>& sinks) {
|
||||||
|
using namespace mlx::steel;
|
||||||
|
|
||||||
|
int wm = 4;
|
||||||
|
int wn = 1;
|
||||||
|
|
||||||
|
int bd = q.shape(-1);
|
||||||
|
int bq = 64;
|
||||||
|
int bk = 32;
|
||||||
|
|
||||||
|
int B = q.shape(0);
|
||||||
|
int H = q.shape(1);
|
||||||
|
int D = q.shape(3);
|
||||||
|
int gqa_factor = q.shape(1) / k.shape(1);
|
||||||
|
|
||||||
|
int qL = q.shape(2);
|
||||||
|
int kL = k.shape(2);
|
||||||
|
|
||||||
|
const bool align_Q = (qL % bq) == 0;
|
||||||
|
const bool align_K = (kL % bk) == 0;
|
||||||
|
const bool has_mask = mask.has_value();
|
||||||
|
const bool do_causal = do_causal_;
|
||||||
|
const bool has_sinks = sinks.has_value();
|
||||||
|
|
||||||
|
metal::MTLFCList func_consts = {
|
||||||
|
{&align_Q, MTL::DataType::DataTypeBool, 200},
|
||||||
|
{&align_K, MTL::DataType::DataTypeBool, 201},
|
||||||
|
{&has_mask, MTL::DataType::DataTypeBool, 300},
|
||||||
|
{&do_causal, MTL::DataType::DataTypeBool, 301},
|
||||||
|
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
|
||||||
|
|
||||||
|
std::string base_name;
|
||||||
|
concatenate(
|
||||||
|
base_name,
|
||||||
|
"steel_attention_",
|
||||||
|
type_to_name(q),
|
||||||
|
"_bq",
|
||||||
|
bq,
|
||||||
|
"_bk",
|
||||||
|
bk,
|
||||||
|
"_bd",
|
||||||
|
bd,
|
||||||
|
"_wm",
|
||||||
|
wm,
|
||||||
|
"_wn",
|
||||||
|
wn,
|
||||||
|
"_mask",
|
||||||
|
type_to_name(has_mask ? *mask : q));
|
||||||
|
|
||||||
|
std::string hash_name;
|
||||||
|
concatenate(
|
||||||
|
hash_name,
|
||||||
|
base_name,
|
||||||
|
"_align_Q_",
|
||||||
|
(align_Q ? 't' : 'n'),
|
||||||
|
"_align_K_",
|
||||||
|
(align_K ? 't' : 'n'),
|
||||||
|
"_has_mask_",
|
||||||
|
(has_mask ? 't' : 'n'),
|
||||||
|
"_do_causal_",
|
||||||
|
(do_causal ? 't' : 'n'),
|
||||||
|
"_has_sinks_",
|
||||||
|
(has_sinks ? 't' : 'n'));
|
||||||
|
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
const int NQ = (qL + bq - 1) / bq;
|
||||||
|
const int NK = (kL + bk - 1) / bk;
|
||||||
|
|
||||||
|
const int NQ_aligned = qL / bq;
|
||||||
|
const int NK_aligned = kL / bk;
|
||||||
|
|
||||||
|
AttnParams params{
|
||||||
|
/* int B = */ B,
|
||||||
|
/* int H = */ H,
|
||||||
|
/* int D = */ D,
|
||||||
|
|
||||||
|
/* int qL = */ qL,
|
||||||
|
/* int kL = */ kL,
|
||||||
|
|
||||||
|
/* int gqa_factor = */ gqa_factor,
|
||||||
|
/* float scale = */ scale,
|
||||||
|
|
||||||
|
/* int NQ = */ NQ,
|
||||||
|
/* int NK = */ NK,
|
||||||
|
|
||||||
|
/* int NQ_aligned = */ NQ_aligned,
|
||||||
|
/* int NK_aligned = */ NK_aligned,
|
||||||
|
|
||||||
|
/* int qL_rem = */ (qL - NQ_aligned * bq),
|
||||||
|
/* int kL_rem = */ (kL - NK_aligned * bk),
|
||||||
|
/* int qL_off = */ (kL - qL),
|
||||||
|
|
||||||
|
/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
|
||||||
|
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
|
||||||
|
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
|
||||||
|
/* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(q, 0);
|
||||||
|
compute_encoder.set_input_array(k, 1);
|
||||||
|
compute_encoder.set_input_array(v, 2);
|
||||||
|
compute_encoder.set_output_array(o, 3);
|
||||||
|
compute_encoder.set_bytes(params, 4);
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
auto& m = *mask;
|
||||||
|
|
||||||
|
AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
|
||||||
|
m.strides(0), m.strides(1), m.strides(2)}};
|
||||||
|
|
||||||
|
compute_encoder.set_bytes(mask_params, 5);
|
||||||
|
compute_encoder.set_input_array(m, 6);
|
||||||
|
}
|
||||||
|
if (has_sinks) {
|
||||||
|
compute_encoder.set_input_array(*sinks, 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
MTL::Size grid_dims = MTL::Size(NQ, H, B);
|
||||||
|
MTL::Size group_dims = MTL::Size(32, wm, wn);
|
||||||
|
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // MLX_ENABLE_NAX
|
||||||
|
|
||||||
void sdpa_full_self_attention_metal(
|
void sdpa_full_self_attention_metal(
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
@@ -23,6 +163,23 @@ void sdpa_full_self_attention_metal(
|
|||||||
bool do_causal_,
|
bool do_causal_,
|
||||||
const std::optional<array>& mask,
|
const std::optional<array>& mask,
|
||||||
const std::optional<array>& sinks) {
|
const std::optional<array>& sinks) {
|
||||||
|
#ifdef MLX_ENABLE_NAX
|
||||||
|
if (metal::is_nax_available() && q.shape(3) != 80 &&
|
||||||
|
(q.dtype() != float32 || env::enable_tf32())) {
|
||||||
|
return sdpa_full_self_attention_nax(
|
||||||
|
/* const Stream& s = */ s,
|
||||||
|
/* metal::Device& d = */ d,
|
||||||
|
/* const array& q = */ q,
|
||||||
|
/* const array& k = */ k,
|
||||||
|
/* const array& v = */ v,
|
||||||
|
/* const float scale = */ scale,
|
||||||
|
/* array& o = */ o,
|
||||||
|
/* bool do_causal_ = */ do_causal_,
|
||||||
|
/* const std::optional<array>& mask = */ mask,
|
||||||
|
/* const std::optional<array>& sinks = */ sinks);
|
||||||
|
}
|
||||||
|
#endif // MLX_ENABLE_NAX
|
||||||
|
|
||||||
using namespace mlx::steel;
|
using namespace mlx::steel;
|
||||||
|
|
||||||
int wm = 4;
|
int wm = 4;
|
||||||
|
|||||||
Reference in New Issue
Block a user