mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Enabling fused attention for head dim 128 (#1899)
* Share KV smem
* Fix bfloat error
* Unroll O = S @ V loop
* Perf upgrade
* Remove commented out function
* Add -Wno-c++17-extensions flag to metal flags
* Add -Wno-c++17-extensions flag to metal extension flags
This commit is contained in:
parent 6bf00ef631
commit 89d327075f
@ -1,5 +1,7 @@
|
|||||||
include(CMakeParseArguments)
|
include(CMakeParseArguments)
|
||||||
|
|
||||||
|
# clang format off
|
||||||
|
#
|
||||||
# ##############################################################################
|
# ##############################################################################
|
||||||
# Build metal library
|
# Build metal library
|
||||||
#
|
#
|
||||||
@ -11,6 +13,8 @@ include(CMakeParseArguments)
|
|||||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||||
# files (like headers)
|
# files (like headers)
|
||||||
#
|
#
|
||||||
|
# clang format on
|
||||||
|
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||||
@ -21,7 +25,7 @@ macro(mlx_build_metallib)
|
|||||||
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
|
||||||
|
|
||||||
# Collect compile options
|
# Collect compile options
|
||||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||||
|
|
||||||
# Prepare metallib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
@ -9,7 +9,7 @@ set(BASE_HEADERS
|
|||||||
utils.h)
|
utils.h)
|
||||||
|
|
||||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||||
if(MLX_METAL_DEBUG)
|
if(MLX_METAL_DEBUG)
|
||||||
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
|
||||||
endif()
|
endif()
|
||||||
|
@ -50,7 +50,7 @@ struct SubOp {
|
|||||||
struct ExpSubOp {
|
struct ExpSubOp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
METAL_FUNC static constexpr T apply(T x, T y) {
|
METAL_FUNC static constexpr T apply(T x, T y) {
|
||||||
return fast::exp(x - y);
|
return fast::exp2(x - y);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -103,17 +103,24 @@ template <
|
|||||||
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
tidl.x * BQ * params->O_strides[2]; // Seqeunce
|
||||||
|
|
||||||
// Prepare threadgroup memory
|
// Prepare threadgroup memory
|
||||||
constexpr short padQ = 0; // 16 / sizeof(T);
|
constexpr short padQ = 16 / sizeof(T);
|
||||||
constexpr short padK = 0; // 16 / sizeof(T);
|
constexpr short padK = 16 / sizeof(T);
|
||||||
constexpr short padV = 0; // 16 / sizeof(T);
|
constexpr short padV = 16 / sizeof(T);
|
||||||
|
|
||||||
constexpr short LDQ_tgp = BD + padQ;
|
constexpr short LDQ_tgp = BD + padQ;
|
||||||
constexpr short LDK_tgp = BK + padK;
|
constexpr short LDK_tgp = BK + padK;
|
||||||
constexpr short LDV_tgp = BD + padV;
|
constexpr short LDV_tgp = BD + padV;
|
||||||
|
|
||||||
threadgroup T Qs[BQ * (BD + padQ)];
|
constexpr short tgp_mem_0 = (BK + padK) * (BD);
|
||||||
threadgroup T Ks[(BK + padK) * BD];
|
constexpr short tgp_mem_1 = BK * (BD + padV);
|
||||||
threadgroup T Vs[BK * (BD + padV)];
|
constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
|
||||||
|
|
||||||
|
threadgroup T Q_smem[BQ * (BD + padQ)];
|
||||||
|
threadgroup T KV_smem[tgp_mem_s];
|
||||||
|
|
||||||
|
threadgroup T* Qs = Q_smem;
|
||||||
|
threadgroup T* Ks = KV_smem;
|
||||||
|
threadgroup T* Vs = KV_smem;
|
||||||
|
|
||||||
// Prepare block loaders
|
// Prepare block loaders
|
||||||
using QBlockLoader = BlockLoaderT<
|
using QBlockLoader = BlockLoaderT<
|
||||||
@ -151,7 +158,7 @@ template <
|
|||||||
VBlockLoader loader_v(
|
VBlockLoader loader_v(
|
||||||
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
|
||||||
|
|
||||||
TransformScale<T> ts(static_cast<T>(params->scale));
|
TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
|
||||||
|
|
||||||
// Prepare MMA tiles
|
// Prepare MMA tiles
|
||||||
constexpr short kFragSize = 8; // MMAFrag size
|
constexpr short kFragSize = 8; // MMAFrag size
|
||||||
@ -174,7 +181,7 @@ template <
|
|||||||
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
|
||||||
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
|
||||||
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
|
||||||
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
|
MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
|
||||||
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
|
||||||
|
|
||||||
Otile.clear();
|
Otile.clear();
|
||||||
@ -224,11 +231,12 @@ template <
|
|||||||
loader_k.load_unsafe();
|
loader_k.load_unsafe();
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
// Do S = Q @ K.T
|
// Do S = Q @ K.T
|
||||||
Stile.clear();
|
Stile.clear();
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short dd = 0; dd < TD; dd++) {
|
for (short dd = 0; dd < TD; dd++) {
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
@ -264,7 +272,7 @@ template <
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// Load V blocks
|
// Load V blocks
|
||||||
if (!align_K && kb == (params->NK_aligned)) {
|
if (!align_K && kb == (params->NK_aligned)) {
|
||||||
@ -292,7 +300,7 @@ template <
|
|||||||
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
// Factor exp(rowmax(Si) - rowmax(Si-1))
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short i = 0; i < kRowsPT; ++i) {
|
for (short i = 0; i < kRowsPT; ++i) {
|
||||||
factor[i] = fast::exp(max_score[i] - new_max[i]);
|
factor[i] = fast::exp2(max_score[i] - new_max[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save max for next iteration
|
// Save max for next iteration
|
||||||
@ -316,12 +324,35 @@ template <
|
|||||||
|
|
||||||
// Load V into registers
|
// Load V into registers
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
|
|
||||||
|
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short iq = 0; iq < TQ; iq++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short id = 0; id < TD; id++) {
|
||||||
|
STEEL_PRAGMA_UNROLL
|
||||||
|
for (short ik = 0; ik < TK; ik++) {
|
||||||
|
if constexpr (BD == 128) {
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
|
||||||
// Do O = S @ V
|
const short kk = ik * kFragSize;
|
||||||
tile_matmad(Otile, Stile, Vtile, Otile);
|
const short dd = id * kFragSize;
|
||||||
|
|
||||||
|
Vtile.template load<T, 1, 1, LDV_tgp, 1>(
|
||||||
|
&Vs[Vs_offset + kk * LDV_tgp + dd]);
|
||||||
|
|
||||||
|
if constexpr (BD == 128) {
|
||||||
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
|
||||||
|
MMAFrag_acc_t::mma(
|
||||||
|
Otile.frag_at(iq, id),
|
||||||
|
Stile.frag_at(iq, ik),
|
||||||
|
Vtile.frag_at(0, 0),
|
||||||
|
Otile.frag_at(iq, id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare for next iteration
|
// Prepare for next iteration
|
||||||
loader_k.next();
|
loader_k.next();
|
||||||
|
@ -62,6 +62,12 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
typedef metal::vec<T, kElemRows> row_frag_type;
|
typedef metal::vec<T, kElemRows> row_frag_type;
|
||||||
typedef metal::vec<T, kElemCols> col_frag_type;
|
typedef metal::vec<T, kElemCols> col_frag_type;
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
|
||||||
|
|
||||||
|
template <typename U>
|
||||||
|
using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
|
||||||
|
|
||||||
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
||||||
[[thread_index_in_simdgroup]]) {
|
[[thread_index_in_simdgroup]]) {
|
||||||
const short qid = simd_lane_id / 4;
|
const short qid = simd_lane_id / 4;
|
||||||
@ -158,30 +164,32 @@ struct BaseMMAFrag<T, 8, 8> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Atype, typename Btype, typename Ctype>
|
||||||
METAL_FUNC static constexpr void mma(
|
METAL_FUNC static constexpr void mma(
|
||||||
thread frag_type& D,
|
thread frag_type& D,
|
||||||
thread frag_type& A,
|
thread dtype_frag_t<Atype>& A,
|
||||||
thread frag_type& B,
|
thread dtype_frag_t<Btype>& B,
|
||||||
thread frag_type& C) {
|
thread dtype_frag_t<Ctype>& C) {
|
||||||
mat_type D_mat;
|
mat_type D_mat;
|
||||||
mat_type A_mat;
|
dtype_mat_t<Atype> A_mat;
|
||||||
mat_type B_mat;
|
dtype_mat_t<Btype> B_mat;
|
||||||
mat_type C_mat;
|
dtype_mat_t<Ctype> C_mat;
|
||||||
|
|
||||||
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
|
||||||
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
|
||||||
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
|
||||||
|
|
||||||
mma(D_mat, A_mat, B_mat, C_mat);
|
mma(D_mat, A_mat, B_mat, C_mat);
|
||||||
|
|
||||||
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Atype, typename Btype, typename Ctype>
|
||||||
METAL_FUNC static constexpr void mma(
|
METAL_FUNC static constexpr void mma(
|
||||||
thread mat_type& D,
|
thread mat_type& D,
|
||||||
thread mat_type& A,
|
thread dtype_mat_t<Atype>& A,
|
||||||
thread mat_type& B,
|
thread dtype_mat_t<Btype>& B,
|
||||||
thread mat_type& C) {
|
thread dtype_mat_t<Ctype>& C) {
|
||||||
simdgroup_multiply_accumulate(D, A, B, C);
|
simdgroup_multiply_accumulate(D, A, B, C);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,7 +250,7 @@ struct MMATile {
|
|||||||
typedef typename MMAFrag_t::mat_type mat_type;
|
typedef typename MMAFrag_t::mat_type mat_type;
|
||||||
typedef typename MMAFrag_t::frag_type frag_type;
|
typedef typename MMAFrag_t::frag_type frag_type;
|
||||||
|
|
||||||
frag_type val_frags[kNumFrags] = {frag_type(0)};
|
frag_type val_frags[kNumFrags]; // = {frag_type(0)};
|
||||||
|
|
||||||
METAL_FUNC MMATile() thread {}
|
METAL_FUNC MMATile() thread {}
|
||||||
|
|
||||||
@ -409,24 +417,37 @@ struct MMATile {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, int M, int N, int K>
|
template <
|
||||||
|
typename Dtype,
|
||||||
|
typename Atype,
|
||||||
|
typename Btype,
|
||||||
|
typename Ctype,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
class MMAFragD,
|
||||||
|
class MMAFragA,
|
||||||
|
class MMAFragB,
|
||||||
|
class MMAFragC>
|
||||||
METAL_FUNC void tile_matmad(
|
METAL_FUNC void tile_matmad(
|
||||||
thread MMATile<T, M, N>& D,
|
thread MMATile<Dtype, M, N, MMAFragD>& D,
|
||||||
thread MMATile<U, M, K>& A,
|
thread MMATile<Atype, M, K, MMAFragA>& A,
|
||||||
thread MMATile<U, K, N>& B,
|
thread MMATile<Btype, K, N, MMAFragB>& B,
|
||||||
thread MMATile<T, M, N>& C) {
|
thread MMATile<Ctype, M, N, MMAFragC>& C) {
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short k = 0; k < K; ++k) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short m = 0; m < M; ++m) {
|
for (short m = 0; m < M; ++m) {
|
||||||
STEEL_PRAGMA_UNROLL
|
STEEL_PRAGMA_UNROLL
|
||||||
for (short n = 0; n < N; ++n) {
|
for (short n = 0; n < N; ++n) {
|
||||||
|
short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
|
||||||
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
||||||
MMATile<T, M, N>::MMAFrag_t::mma(
|
|
||||||
D.frag_at(m, n_serp),
|
STEEL_PRAGMA_UNROLL
|
||||||
A.frag_at(m, k),
|
for (short k = 0; k < K; ++k) {
|
||||||
|
MMAFragD::mma(
|
||||||
|
D.frag_at(m_serp, n_serp),
|
||||||
|
A.frag_at(m_serp, k),
|
||||||
B.frag_at(k, n_serp),
|
B.frag_at(k, n_serp),
|
||||||
C.frag_at(m, n_serp));
|
C.frag_at(m_serp, n_serp));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -693,7 +693,7 @@ array scaled_dot_product_attention(
|
|||||||
query_head_dim == value_head_dim &&
|
query_head_dim == value_head_dim &&
|
||||||
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
|
||||||
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
|
||||||
(query_head_dim == 64 || query_head_dim == 80);
|
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
|
||||||
|
|
||||||
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
|
const bool supports_sdpa_full = query_sequence_length >= threshold && !mask &&
|
||||||
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
sdpa_full_supported_head_dim && stream.device == Device::gpu;
|
||||||
|
Loading…
Reference in New Issue
Block a user