diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 6f2354897..3270b0056 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -1,5 +1,7 @@
 include(CMakeParseArguments)
 
+# clang format off
+#
 # ##############################################################################
 # Build metal library
 #
@@ -11,6 +13,8 @@ include(CMakeParseArguments)
 # of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
 # files (like headers)
 #
+# clang format on
+
 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
@@ -21,7 +25,7 @@ macro(mlx_build_metallib)
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
 
   # Collect compile options
-  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
+  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
 
   # Prepare metallib build command
   add_custom_command(
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 2951188e9..f7dae3121 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -9,7 +9,7 @@ set(BASE_HEADERS
   utils.h)
 
 function(build_kernel_base TARGET SRCFILE DEPS)
-  set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
+  set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
   if(MLX_METAL_DEBUG)
     set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
   endif()
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
index c5c69c30b..b2e70ef8d 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
@@ -50,7 +50,7 @@ struct SubOp {
 struct ExpSubOp {
   template <typename T>
   METAL_FUNC static constexpr T apply(T x, T y) {
-    return fast::exp(x - y);
+    return fast::exp2(x - y);
  }
 };
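The fast::exp -> fast::exp2 switch above works together with the TransformScale change further down in this file, which folds log2(e) ~= 1.44269504089 into the softmax scale. Since exp(s * x) == exp2(s * log2(e) * x), the kernel can use the cheaper hardware exp2 for both the softmax numerator and the running-max rescale factor without changing the math. A minimal standalone C++ sketch of the identity; the values of scale, x, and row_max are illustrative, not taken from the kernel:

    #include <cmath>
    #include <cstdio>

    int main() {
      const float scale = 0.125f;          // e.g. 1/sqrt(head_dim) for head_dim 64
      const float log2e = 1.44269504089f;  // the constant folded into TransformScale
      const float x = 3.7f, row_max = 4.2f;

      // Base-e softmax numerator vs. base-2 with log2(e) pre-multiplied into the scale.
      float base_e = std::exp(scale * (x - row_max));
      float base_2 = std::exp2(scale * log2e * x - scale * log2e * row_max);

      std::printf("%.7f %.7f\n", base_e, base_2);  // agree to float precision
      return 0;
    }

Because the row max is computed on the already-rescaled scores, subtracting it inside exp2 is exactly the usual max-subtraction trick, just in base 2.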
@@ -103,17 +103,24 @@
       tidl.x * BQ * params->O_strides[2]; // Seqeunce
 
   // Prepare threadgroup memory
-  constexpr short padQ = 0; // 16 / sizeof(T);
-  constexpr short padK = 0; // 16 / sizeof(T);
-  constexpr short padV = 0; // 16 / sizeof(T);
+  constexpr short padQ = 16 / sizeof(T);
+  constexpr short padK = 16 / sizeof(T);
+  constexpr short padV = 16 / sizeof(T);
 
   constexpr short LDQ_tgp = BD + padQ;
   constexpr short LDK_tgp = BK + padK;
   constexpr short LDV_tgp = BD + padV;
 
-  threadgroup T Qs[BQ * (BD + padQ)];
-  threadgroup T Ks[(BK + padK) * BD];
-  threadgroup T Vs[BK * (BD + padV)];
+  constexpr short tgp_mem_0 = (BK + padK) * (BD);
+  constexpr short tgp_mem_1 = BK * (BD + padV);
+  constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1;
+
+  threadgroup T Q_smem[BQ * (BD + padQ)];
+  threadgroup T KV_smem[tgp_mem_s];
+
+  threadgroup T* Qs = Q_smem;
+  threadgroup T* Ks = KV_smem;
+  threadgroup T* Vs = KV_smem;
 
   // Prepare block loaders
   using QBlockLoader = BlockLoaderT<
@@ -151,7 +158,7 @@
   VBlockLoader loader_v(
       V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
 
-  TransformScale<T> ts(static_cast<T>(params->scale));
+  TransformScale<T> ts(static_cast<T>(params->scale * 1.44269504089));
 
   // Prepare MMA tiles
   constexpr short kFragSize = 8; // MMAFrag size
@@ -174,7 +181,7 @@
   MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
   MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
   MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
-  MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
+  MMATile<AccumType, 1, 1, MMAFrag_acc_t> Vtile;
   MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
 
   Otile.clear();
@@ -224,11 +231,12 @@
       loader_k.load_unsafe();
     }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
     // Do S = Q @ K.T
     Stile.clear();
 
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    STEEL_PRAGMA_UNROLL
     for (short dd = 0; dd < TD; dd++) {
       simdgroup_barrier(mem_flags::mem_none);
 
@@ -264,7 +272,7 @@
       }
     }
 
-    simdgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Load V blocks
     if (!align_K && kb == (params->NK_aligned)) {
@@ -292,7 +300,7 @@
     // Factor exp(rowmax(Si) - rowmax(Si-1))
     STEEL_PRAGMA_UNROLL
     for (short i = 0; i < kRowsPT; ++i) {
-      factor[i] = fast::exp(max_score[i] - new_max[i]);
+      factor[i] = fast::exp2(max_score[i] - new_max[i]);
     }
 
     // Save max for next iteration
@@ -316,12 +324,35 @@
     // Load V into registers
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
 
-    simdgroup_barrier(mem_flags::mem_none);
+    STEEL_PRAGMA_UNROLL
+    for (short iq = 0; iq < TQ; iq++) {
+      STEEL_PRAGMA_UNROLL
+      for (short id = 0; id < TD; id++) {
+        STEEL_PRAGMA_UNROLL
+        for (short ik = 0; ik < TK; ik++) {
+          if constexpr (BD == 128) {
+            simdgroup_barrier(mem_flags::mem_none);
+          }
 
-    // Do O = S @ V
-    tile_matmad(Otile, Stile, Vtile, Otile);
+          const short kk = ik * kFragSize;
+          const short dd = id * kFragSize;
+
+          Vtile.template load<T, 1, 1, LDV_tgp, 1>(
+              &Vs[Vs_offset + kk * LDV_tgp + dd]);
+
+          if constexpr (BD == 128) {
+            simdgroup_barrier(mem_flags::mem_none);
+          }
+
+          MMAFrag_acc_t::mma(
+              Otile.frag_at(iq, id),
+              Stile.frag_at(iq, ik),
+              Vtile.frag_at(0, 0),
+              Otile.frag_at(iq, id));
+        }
+      }
+    }
 
     // Prepare for next iteration
     loader_k.next();
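The Q_smem/KV_smem change above is what makes head_dim 128 fit in threadgroup memory: the K tile and the V tile are never live at the same time (S = Q @ K.T finishes before V is staged), so Ks and Vs can alias one buffer sized for the larger of the two, and the barrier promoted from simdgroup to threadgroup scope before the V load keeps the overwrite safe. A rough C++ model of the sizing arithmetic; KVScratch and the parameter values are stand-ins for illustration, not the kernel's actual types:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    template <typename T, int BK, int BD, int padK, int padV>
    struct KVScratch {
      static constexpr size_t k_elems = (BK + padK) * BD;  // K tile, padded rows
      static constexpr size_t v_elems = BK * (BD + padV);  // V tile, padded cols
      T buf[std::max(k_elems, v_elems)];                   // one shared backing store
      T* k() { return buf; }  // valid until S = Q @ K.T completes
      T* v() { return buf; }  // reused for V after a threadgroup barrier
    };

    int main() {
      // e.g. BK = 16, BD = 128, pad = 16 bytes / sizeof(half) = 8 elements
      KVScratch<short, 16, 128, 8, 8> s;  // short stands in for Metal's half
      std::printf("shared elems: %zu\n", sizeof(s.buf) / sizeof(short));
      return 0;
    }

Compared with separate Ks and Vs arrays, the shared buffer saves roughly a full V tile of threadgroup memory per threadgroup, which is the headroom the BD == 128 instantiation needs.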
diff --git a/mlx/backend/metal/kernels/steel/attn/mma.h b/mlx/backend/metal/kernels/steel/attn/mma.h
index 5ddd37ac6..525c50e8f 100644
--- a/mlx/backend/metal/kernels/steel/attn/mma.h
+++ b/mlx/backend/metal/kernels/steel/attn/mma.h
@@ -62,6 +62,12 @@ struct BaseMMAFrag<T, 8, 8> {
   typedef metal::vec<T, kElemRows> row_frag_type;
   typedef metal::vec<T, kElemCols> col_frag_type;
 
+  template <typename U>
+  using dtype_mat_t = typename metal::simdgroup_matrix<U, kFragRows, kFragCols>;
+
+  template <typename U>
+  using dtype_frag_t = typename metal::vec<U, kElemsPerFrag>;
+
   METAL_FUNC static constexpr short2 get_coord(
       ushort simd_lane_id [[thread_index_in_simdgroup]]) {
     const short qid = simd_lane_id / 4;
@@ -158,30 +164,32 @@
     }
   }
 
+  template <typename Atype, typename Btype, typename Ctype>
   METAL_FUNC static constexpr void mma(
       thread frag_type& D,
-      thread frag_type& A,
-      thread frag_type& B,
-      thread frag_type& C) {
+      thread dtype_frag_t<Atype>& A,
+      thread dtype_frag_t<Btype>& B,
+      thread dtype_frag_t<Ctype>& C) {
     mat_type D_mat;
-    mat_type A_mat;
-    mat_type B_mat;
-    mat_type C_mat;
+    dtype_mat_t<Atype> A_mat;
+    dtype_mat_t<Btype> B_mat;
+    dtype_mat_t<Ctype> C_mat;
 
-    reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
-    reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
-    reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
+    reinterpret_cast<thread dtype_frag_t<Atype>&>(A_mat.thread_elements()) = A;
+    reinterpret_cast<thread dtype_frag_t<Btype>&>(B_mat.thread_elements()) = B;
+    reinterpret_cast<thread dtype_frag_t<Ctype>&>(C_mat.thread_elements()) = C;
 
     mma(D_mat, A_mat, B_mat, C_mat);
 
     D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
   }
 
+  template <typename Atype, typename Btype, typename Ctype>
   METAL_FUNC static constexpr void mma(
       thread mat_type& D,
-      thread mat_type& A,
-      thread mat_type& B,
-      thread mat_type& C) {
+      thread dtype_mat_t<Atype>& A,
+      thread dtype_mat_t<Btype>& B,
+      thread dtype_mat_t<Ctype>& C) {
     simdgroup_multiply_accumulate(D, A, B, C);
   }
 
@@ -242,7 +250,7 @@ struct MMATile {
   typedef typename MMAFrag_t::mat_type mat_type;
   typedef typename MMAFrag_t::frag_type frag_type;
 
-  frag_type val_frags[kNumFrags] = {frag_type(0)};
+  frag_type val_frags[kNumFrags]; // = {frag_type(0)};
 
   METAL_FUNC MMATile() thread {}
 
@@ -409,24 +417,37 @@
   }
 };
 
-template <typename T, int M, int N, int K, class MMAFrag>
+template <
+    typename Dtype,
+    typename Atype,
+    typename Btype,
+    typename Ctype,
+    int M,
+    int N,
+    int K,
+    class MMAFragD,
+    class MMAFragA,
+    class MMAFragB,
+    class MMAFragC>
 METAL_FUNC void tile_matmad(
-    thread MMATile<T, M, N, MMAFrag>& D,
-    thread MMATile<T, M, K, MMAFrag>& A,
-    thread MMATile<T, K, N, MMAFrag>& B,
-    thread MMATile<T, M, N, MMAFrag>& C) {
+    thread MMATile<Dtype, M, N, MMAFragD>& D,
+    thread MMATile<Atype, M, K, MMAFragA>& A,
+    thread MMATile<Btype, K, N, MMAFragB>& B,
+    thread MMATile<Ctype, M, N, MMAFragC>& C) {
   STEEL_PRAGMA_UNROLL
-  for (short k = 0; k < K; ++k) {
+  for (short m = 0; m < M; ++m) {
     STEEL_PRAGMA_UNROLL
-    for (short m = 0; m < M; ++m) {
+    for (short n = 0; n < N; ++n) {
+      short m_serp = m; //(n % 2) ? (M - 1 - m) : m;
+      short n_serp = (m % 2) ? (N - 1 - n) : n;
+
       STEEL_PRAGMA_UNROLL
-      for (short n = 0; n < N; ++n) {
-        short n_serp = (m % 2) ? (N - 1 - n) : n;
-        MMATile<T, M, N, MMAFrag>::MMAFrag_t::mma(
-            D.frag_at(m, n_serp),
-            A.frag_at(m, k),
+      for (short k = 0; k < K; ++k) {
+        MMAFragD::mma(
+            D.frag_at(m_serp, n_serp),
+            A.frag_at(m_serp, k),
             B.frag_at(k, n_serp),
-            C.frag_at(m, n_serp));
+            C.frag_at(m_serp, n_serp));
       }
     }
   }
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index fadf594d0..1967c018f 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -693,7 +693,7 @@ array scaled_dot_product_attention(
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80);
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
       !mask && sdpa_full_supported_head_dim && stream.device == Device::gpu;
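With the threadgroup-memory savings in place, the fast.cpp hunk above is the user-visible part of the change: head_dim 128 now routes to the full SDPA kernel. The rewritten tile_matmad also generalizes the fragment element types (separate Dtype/Atype/Btype/Ctype, so lower-precision inputs can accumulate into float simdgroup matrices) and walks the output tile in a serpentine order, reversing the n direction on odd m rows so consecutive MMA issues touch neighboring fragments. A scalar C++ reference sketch of just that traversal, with the fragment math replaced by plain float accumulation and illustrative M, N, K:

    #include <cstdio>

    int main() {
      constexpr int M = 2, N = 4, K = 3;
      float A[M][K] = {{1, 2, 3}, {4, 5, 6}};
      float B[K][N] = {{1, 0, 0, 1}, {0, 1, 1, 0}, {1, 1, 0, 0}};
      float D[M][N] = {};

      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          // Snake across the row: odd rows run right-to-left.
          int n_serp = (m % 2) ? (N - 1 - n) : n;
          for (int k = 0; k < K; ++k) {
            D[m][n_serp] += A[m][k] * B[k][n_serp];
          }
        }
      }

      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
          std::printf("%5.1f%c", D[m][n], n + 1 == N ? '\n' : ' ');
      return 0;
    }

Every (m, n) pair is still visited exactly once, so the product is identical to the natural row-major order; only the issue order, and with it register reuse, changes.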