mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Enabling fused attention for head dim 128 (#1899)
* Share KV smem
* Fix bfloat error
* Unroll O = S @ V loop
* Perf upgrade
* Remove commented out function
* Add -Wno-c++17-extensions flag to metal flags
* Add -Wno-c++17-extensions flag to metal extension flags
This commit is contained in:
parent
6bf00ef631
commit
89d327075f
@ -1,5 +1,7 @@
|
||||
include(CMakeParseArguments)
|
||||
|
||||
# clang format off
|
||||
#
|
||||
# ##############################################################################
|
||||
# Build metal library
|
||||
#
|
||||
@ -11,6 +13,8 @@ include(CMakeParseArguments)
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers)
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
@ -21,7 +25,7 @@ macro(mlx_build_metallib)
|
||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
@ -9,7 +9,7 @@ set(BASE_HEADERS
|
||||
utils.h)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||
endif()
|
||||
|
@ -50,7 +50,7 @@ struct SubOp {
|
||||
struct ExpSubOp {
|
||||
template <typename T>
|
||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||
return fast::exp(x - y);
|
||||
return fast::exp2(x - y);
|
||||
}
|
||||
};
|
||||
|
||||
@ -103,17 +103,24 @@ template <
|
||||
tidl.x * BQ * params->O_strides[2]; // Sequence
|
||||
|
||||
// Prepare threadgroup memory
|
||||
constexpr short padQ = 0; // 16 / sizeof(T);
|
||||
constexpr short padK = 0; // 16 / sizeof(T);
|
||||
constexpr short padV = 0; // 16 / sizeof(T);
|
||||
constexpr short padQ = 16 / sizeof(T);
|
||||
constexpr short padK = 16 / sizeof(T);
|
||||
constexpr short padV = 16 / sizeof(T);
|
||||
|
||||
constexpr short LDQ_tgp = BD + padQ;
|
||||
constexpr short LDK_tgp = BK + padK;
|
||||
constexpr short LDV_tgp = BD + padV;
|
||||
|
||||
threadgroup T Qs[BQ * (BD + padQ)];
|
||||
threadgroup T Ks[(BK + padK) * BD];
|
||||
threadgroup T Vs[BK * (BD + padV)];
|
||||
constexpr short tgp_mem_0 = (BK + padK) * (BD);
|
||||
constexpr short tgp_mem_1 = BK * (BD + padV);
|
||||
constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
|
||||
|
||||
threadgroup T Q_smem[BQ * (BD + padQ)];
|
||||
threadgroup T KV_smem[tgp_mem_s];
|
||||
|
||||
threadgroup T* Qs = Q_smem;
|
||||
threadgroup T* Ks = KV_smem;
|
||||
threadgroup T* Vs = KV_smem;
|
||||
|
||||
// Prepare block loaders
|
||||
using QBlockLoader = BlockLoaderT<
|
||||
@ -151,7 +158,7 @@ template <
|
||||
VBlockLoader loader_v(
|
||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||
|
||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
||||
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
|
||||
|
||||
// Prepare MMA tiles
|
||||
constexpr short kFragSize = 8; // MMAFrag size
|
||||
@ -174,7 +181,7 @@ template <
|
||||
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
||||
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
||||
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
|
||||
MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
|
||||
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
||||
|
||||
Otile.clear();
|
||||
@ -224,11 +231,12 @@ template <
|
||||
loader_k.load_unsafe();
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Do S = Q @ K.T
|
||||
Stile.clear();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short dd = 0; dd < TD; dd++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
@ -264,7 +272,7 @@ template <
|
||||
}
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load V blocks
|
||||
if (!align_K && kb == (params->NK_aligned)) {
|
||||
@ -292,7 +300,7 @@ template <
|
||||
// Factor exp(rowmax(Si-1) - rowmax(Si))
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < kRowsPT; ++i) {
|
||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
||||
factor[i] = fast::exp2(max_score[i] - new_max[i]);
|
||||
}
|
||||
|
||||
// Save max for next iteration
|
||||
@ -316,12 +324,35 @@ template <
|
||||
|
||||
// Load V into registers
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short iq = 0; iq < TQ; iq++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short id = 0; id < TD; id++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short ik = 0; ik < TK; ik++) {
|
||||
if constexpr (BD == 128) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
// Do O = S @ V
|
||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
||||
const short kk = ik * kFragSize;
|
||||
const short dd = id * kFragSize;
|
||||
|
||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(
|
||||
&Vs[Vs_offset + kk * LDV_tgp + dd]);
|
||||
|
||||
if constexpr (BD == 128) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
MMAFrag_acc_t::mma(
|
||||
Otile.frag_at(iq, id),
|
||||
Stile.frag_at(iq, ik),
|
||||
Vtile.frag_at(0, 0),
|
||||
Otile.frag_at(iq, id));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare for next iteration
|
||||
loader_k.next();
|
||||
|
@ -62,6 +62,12 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
typedef metal::vec<T, kElemRows> row_frag_type;
|
||||
typedef metal::vec<T, kElemCols> col_frag_type;
|
||||
|
||||
template <typename U>
|
||||
using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
|
||||
|
||||
template <typename U>
|
||||
using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
|
||||
|
||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||
[[thread_index_in_simdgroup]]) {
|
||||
const short qid = simd_lane_id / 4;
|
||||
@ -158,30 +164,32 @@ struct BaseMMAFrag<T, 8, 8> {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Atype, typename Btype, typename Ctype>
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread frag_type& D,
|
||||
thread frag_type& A,
|
||||
thread frag_type& B,
|
||||
thread frag_type& C) {
|
||||
thread dtype_frag_t<Atype>& A,
|
||||
thread dtype_frag_t<Btype>& B,
|
||||
thread dtype_frag_t<Ctype>& C) {
|
||||
mat_type D_mat;
|
||||
mat_type A_mat;
|
||||
mat_type B_mat;
|
||||
mat_type C_mat;
|
||||
dtype_mat_t<Atype> A_mat;
|
||||
dtype_mat_t<Btype> B_mat;
|
||||
dtype_mat_t<Ctype> C_mat;
|
||||
|
||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
||||
reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
|
||||
reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
|
||||
reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
|
||||
|
||||
mma(D_mat, A_mat, B_mat, C_mat);
|
||||
|
||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||
}
|
||||
|
||||
template <typename Atype, typename Btype, typename Ctype>
|
||||
METAL_FUNC static constexpr void mma(
|
||||
thread mat_type& D,
|
||||
thread mat_type& A,
|
||||
thread mat_type& B,
|
||||
thread mat_type& C) {
|
||||
thread dtype_mat_t<Atype>& A,
|
||||
thread dtype_mat_t<Btype>& B,
|
||||
thread dtype_mat_t<Ctype>& C) {
|
||||
simdgroup_multiply_accumulate(D, A, B, C);
|
||||
}
|
||||
|
||||
@ -242,7 +250,7 @@ struct MMATile {
|
||||
typedef typename MMAFrag_t::mat_type mat_type;
|
||||
typedef typename MMAFrag_t::frag_type frag_type;
|
||||
|
||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
||||
frag_type val_frags[kNumFrags]; // = {frag_type(0)};
|
||||
|
||||
METAL_FUNC MMATile() thread {}
|
||||
|
||||
@ -409,24 +417,37 @@ struct MMATile {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, int M, int N, int K>
|
||||
template <
|
||||
typename Dtype,
|
||||
typename Atype,
|
||||
typename Btype,
|
||||
typename Ctype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
class MMAFragD,
|
||||
class MMAFragA,
|
||||
class MMAFragB,
|
||||
class MMAFragC>
|
||||
METAL_FUNC void tile_matmad(
|
||||
thread MMATile<T, M, N>& D,
|
||||
thread MMATile<U, M, K>& A,
|
||||
thread MMATile<U, K, N>& B,
|
||||
thread MMATile<T, M, N>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
thread MMATile<Dtype, M, N, MMAFragD>& D,
|
||||
thread MMATile<Atype, M, K, MMAFragA>& A,
|
||||
thread MMATile<Btype, K, N, MMAFragB>& B,
|
||||
thread MMATile<Ctype, M, N, MMAFragC>& C) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short m = 0; m < M; ++m) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short n = 0; n < N; ++n) {
|
||||
short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
|
||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
||||
D.frag_at(m, n_serp),
|
||||
A.frag_at(m, k),
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short k = 0; k < K; ++k) {
|
||||
MMAFragD::mma(
|
||||
D.frag_at(m_serp, n_serp),
|
||||
A.frag_at(m_serp, k),
|
||||
B.frag_at(k, n_serp),
|
||||
C.frag_at(m, n_serp));
|
||||
C.frag_at(m_serp, n_serp));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -693,7 +693,7 @@ array scaled_dot_product_attention(
|
||||
query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||
(query_head_dim == 64 || query_head_dim == 80);
|
||||
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||
|
||||
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
|
||||
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
||||
|
Loading…
Reference in New Issue
Block a user