Enabling fused attention for head dim 128 (#1899)

* Share KV smem

* Fix bfloat error

* Unroll O = S @ V loop

* Perf upgrade

* Remove commented out function

* Add -Wno-c++17-extensions flag to metal flags

* Add -Wno-c++17-extensions flag to metal extension flags
This commit is contained in:
Jagrit Digani 2025-02-26 10:02:06 -08:00 committed by GitHub
parent 6bf00ef631
commit 89d327075f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 102 additions and 46 deletions

View File

@ -1,5 +1,7 @@
include(CMakeParseArguments) include(CMakeParseArguments)
# clang format off
#
# ############################################################################## # ##############################################################################
# Build metal library # Build metal library
# #
@ -11,6 +13,8 @@ include(CMakeParseArguments)
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) # files (like headers)
# #
# clang format on
macro(mlx_build_metallib) macro(mlx_build_metallib)
# Parse args # Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY) set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
@ -21,7 +25,7 @@ macro(mlx_build_metallib)
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib") set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options # Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math) set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
# Prepare metallib build command # Prepare metallib build command
add_custom_command( add_custom_command(

View File

@ -9,7 +9,7 @@ set(BASE_HEADERS
utils.h) utils.h)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()

View File

@ -50,7 +50,7 @@ struct SubOp {
struct ExpSubOp { struct ExpSubOp {
template <typename T> template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) { METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp(x - y); return fast::exp2(x - y);
} }
}; };
@ -103,17 +103,24 @@ template <
tidl.x * BQ * params->O_strides[2]; // Sequence tidl.x * BQ * params->O_strides[2]; // Sequence
// Prepare threadgroup memory // Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T); constexpr short padQ = 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T); constexpr short padK = 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T); constexpr short padV = 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ; constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK; constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV; constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)]; constexpr short tgp_mem_0 = (BK + padK) * (BD);
threadgroup T Ks[(BK + padK) * BD]; constexpr short tgp_mem_1 = BK * (BD + padV);
threadgroup T Vs[BK * (BD + padV)]; constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
threadgroup T Q_smem[BQ * (BD + padQ)];
threadgroup T KV_smem[tgp_mem_s];
threadgroup T* Qs = Q_smem;
threadgroup T* Ks = KV_smem;
threadgroup T* Vs = KV_smem;
// Prepare block loaders // Prepare block loaders
using QBlockLoader = BlockLoaderT< using QBlockLoader = BlockLoaderT<
@ -151,7 +158,7 @@ template <
VBlockLoader loader_v( VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale)); TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
// Prepare MMA tiles // Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size constexpr short kFragSize = 8; // MMAFrag size
@ -174,7 +181,7 @@ template <
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile; MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile; MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile; MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile; MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile; MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear(); Otile.clear();
@ -224,11 +231,12 @@ template <
loader_k.load_unsafe(); loader_k.load_unsafe();
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T // Do S = Q @ K.T
Stile.clear(); Stile.clear();
threadgroup_barrier(mem_flags::mem_threadgroup);
STEEL_PRAGMA_UNROLL
for (short dd = 0; dd < TD; dd++) { for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none); simdgroup_barrier(mem_flags::mem_none);
@ -264,7 +272,7 @@ template <
} }
} }
simdgroup_barrier(mem_flags::mem_none); threadgroup_barrier(mem_flags::mem_threadgroup);
// Load V blocks // Load V blocks
if (!align_K && kb == (params->NK_aligned)) { if (!align_K && kb == (params->NK_aligned)) {
@ -292,7 +300,7 @@ template <
// Factor exp(rowmax(Si) - rowmax(Si-1)) // Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) { for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]); factor[i] = fast::exp2(max_score[i] - new_max[i]);
} }
// Save max for next iteration // Save max for next iteration
@ -316,12 +324,35 @@ template <
// Load V into registers // Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short id = 0; id < TD; id++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
if constexpr (BD == 128) {
simdgroup_barrier(mem_flags::mem_none); simdgroup_barrier(mem_flags::mem_none);
}
// Do O = S @ V const short kk = ik * kFragSize;
tile_matmad(Otile, Stile, Vtile, Otile); const short dd = id * kFragSize;
Vtile.template load<T, 1, 1, LDV_tgp, 1>(
&Vs[Vs_offset + kk * LDV_tgp + dd]);
if constexpr (BD == 128) {
simdgroup_barrier(mem_flags::mem_none);
}
MMAFrag_acc_t::mma(
Otile.frag_at(iq, id),
Stile.frag_at(iq, ik),
Vtile.frag_at(0, 0),
Otile.frag_at(iq, id));
}
}
}
// Prepare for next iteration // Prepare for next iteration
loader_k.next(); loader_k.next();

View File

@ -62,6 +62,12 @@ struct BaseMMAFrag<T, 8, 8> {
typedef metal::vec<T, kElemRows> row_frag_type; typedef metal::vec<T, kElemRows> row_frag_type;
typedef metal::vec<T, kElemCols> col_frag_type; typedef metal::vec<T, kElemCols> col_frag_type;
template <typename U>
using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
template <typename U>
using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) { [[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4; const short qid = simd_lane_id / 4;
@ -158,30 +164,32 @@ struct BaseMMAFrag<T, 8, 8> {
} }
} }
template <typename Atype, typename Btype, typename Ctype>
METAL_FUNC static constexpr void mma( METAL_FUNC static constexpr void mma(
thread frag_type& D, thread frag_type& D,
thread frag_type& A, thread dtype_frag_t<Atype>& A,
thread frag_type& B, thread dtype_frag_t<Btype>& B,
thread frag_type& C) { thread dtype_frag_t<Ctype>& C) {
mat_type D_mat; mat_type D_mat;
mat_type A_mat; dtype_mat_t<Atype> A_mat;
mat_type B_mat; dtype_mat_t<Btype> B_mat;
mat_type C_mat; dtype_mat_t<Ctype> C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A; reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B; reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C; reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat); mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements()); D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
} }
template <typename Atype, typename Btype, typename Ctype>
METAL_FUNC static constexpr void mma( METAL_FUNC static constexpr void mma(
thread mat_type& D, thread mat_type& D,
thread mat_type& A, thread dtype_mat_t<Atype>& A,
thread mat_type& B, thread dtype_mat_t<Btype>& B,
thread mat_type& C) { thread dtype_mat_t<Ctype>& C) {
simdgroup_multiply_accumulate(D, A, B, C); simdgroup_multiply_accumulate(D, A, B, C);
} }
@ -242,7 +250,7 @@ struct MMATile {
typedef typename MMAFrag_t::mat_type mat_type; typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type; typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)}; frag_type val_frags[kNumFrags]; // = {frag_type(0)};
METAL_FUNC MMATile() thread {} METAL_FUNC MMATile() thread {}
@ -409,24 +417,37 @@ struct MMATile {
} }
}; };
template <typename T, typename U, int M, int N, int K> template <
typename Dtype,
typename Atype,
typename Btype,
typename Ctype,
int M,
int N,
int K,
class MMAFragD,
class MMAFragA,
class MMAFragB,
class MMAFragC>
METAL_FUNC void tile_matmad( METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D, thread MMATile<Dtype, M, N, MMAFragD>& D,
thread MMATile<U, M, K>& A, thread MMATile<Atype, M, K, MMAFragA>& A,
thread MMATile<U, K, N>& B, thread MMATile<Btype, K, N, MMAFragB>& B,
thread MMATile<T, M, N>& C) { thread MMATile<Ctype, M, N, MMAFragC>& C) {
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) { for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) { for (short n = 0; n < N; ++n) {
short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
short n_serp = (m % 2) ? (N - 1 - n) : n; short n_serp = (m % 2) ? (N - 1 - n) : n;
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp), STEEL_PRAGMA_UNROLL
A.frag_at(m, k), for (short k = 0; k < K; ++k) {
MMAFragD::mma(
D.frag_at(m_serp, n_serp),
A.frag_at(m_serp, k),
B.frag_at(k, n_serp), B.frag_at(k, n_serp),
C.frag_at(m, n_serp)); C.frag_at(m_serp, n_serp));
} }
} }
} }

View File

@ -693,7 +693,7 @@ array scaled_dot_product_attention(
query_head_dim == value_head_dim && query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128); (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80); (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask && const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
sdpa_full_supported_head_dim && stream.device == Device::gpu; sdpa_full_supported_head_dim && stream.device == Device::gpu;