diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 7aa648533..db85a4d12 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index d0c872451..f1de1db91 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -102,6 +102,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 1de5fa47c..119f69280 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( int wn, bool transpose); +MTL::ComputePipelineState* get_paged_attention_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const std::string&); + // Create a GPU kernel template definition for JIT compilation template std::string diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 3ee88ca46..4c411c23f 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT) reduction/reduce_row.h) build_kernel(quantized quantized.h ${STEEL_HEADERS}) build_kernel(scan scan.h) + build_kernel(paged_attention paged_attention.h) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) build_kernel(sort sort.h) diff --git a/mlx/backend/metal/kernels/paged_attention.h b/mlx/backend/metal/kernels/paged_attention.h new file mode 100644 index 000000000..2e11ddc64 --- /dev/null +++ b/mlx/backend/metal/kernels/paged_attention.h @@ -0,0 +1,1196 @@ +// Updated from MLX commit has f70764a + +#include +#include +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +// ========================================== Generic vector types + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline Acc mul(A a, B b); + +template +inline float sum(T v); + +template +inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline float dot(T a, T b) { + return sum(mul(a, b)); +} + +// FP32 vector data types. 
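+// Metal's built-in vector types stop at four lanes, so Float8_ packs two
+// float4s; the mul/sum/fma overloads below treat it lane-wise.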
+struct Float8_ { + float4 x; + float4 y; +}; + +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; +template <> +struct Vec { + using Type = Float8_; +}; + +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +template <> +inline float mul(float a, float b) { + return a * b; +} + +template <> +inline float2 mul(float2 a, float2 b) { + return a * b; +} + +template <> +inline float4 mul(float4 a, float4 b) { + return a * b; +} + +template <> +inline Float8_ mul(Float8_ a, Float8_ b) { + Float8_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline float sum(float a) { + return a; +} + +template <> +inline float sum(float2 a) { + return a.x + a.y; +} + +template <> +inline float sum(float4 a) { + return a.x + a.y + a.z + a.w; +} + +template <> +inline float sum(Float8_ a) { + return sum(a.x) + sum(a.y); +} + +inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { + Float8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread float& dst, float src) { + dst = src; +} +inline void from_float(thread float2& dst, float2 src) { + dst = src; +} +inline void from_float(thread float4& dst, float4 src) { + dst = src; +} +inline void from_float(thread Float8_& dst, Float8_ src) { + dst = src; +} + +// BF16 vector data types. +// #if defined(__HAVE_BFLOAT__) + +// struct Bfloat8_ { +// bfloat4 x; +// bfloat4 y; +// }; + +// template<> +// struct Vec { +// using Type = bfloat; +// }; +// template<> +// struct Vec { +// using Type = bfloat2; +// }; +// template<> +// struct Vec { +// using Type = bfloat4; +// }; +// template<> +// struct Vec { +// using Type = Bfloat8_; +// }; + +// template<> +// struct FloatVec { +// using Type = float; +// }; +// template<> +// struct FloatVec { +// using Type = float2; +// }; +// template<> +// struct FloatVec { +// using Type = float4; +// }; +// template<> +// struct FloatVec { +// using Type = Float8_; +// }; + +// template<> +// inline float mul(bfloat a, bfloat b) { +// return (float)a * (float)b; +// } +// template<> +// inline bfloat mul(bfloat a, bfloat b) { +// return a*b; +// } + +// template<> +// inline float2 mul(bfloat2 a, bfloat2 b) { +// return (float2)a * (float2)b; +// } +// template<> +// inline bfloat2 mul(bfloat2 a, bfloat2 b) { +// return a * b; +// } + +// template<> +// inline float4 mul(bfloat4 a, bfloat4 b) { +// return (float4)a * (float4)b; +// } +// template<> +// inline bfloat4 mul(bfloat4 a, bfloat4 b) { +// return a * b; +// } + +// template<> +// inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Float8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } +// template<> +// inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +// Bfloat8_ c; +// c.x = mul(a.x, b.x); +// c.y = mul(a.y, b.y); +// return c; +// } + +// template<> +// inline float sum(bfloat a) { +// return (float)a; +// } + +// template<> +// inline float sum(bfloat2 a) { +// return (float)a.x + (float)a.y; +// } + +// template<> +// inline float sum(bfloat4 a) { +// return sum(a.x) + sum(a.y); +// } + +// template<> +// inline float sum(Bfloat8_ a) { +// return sum(a.x) + sum(a.y); +// } + +// inline float fma(bfloat a, bfloat b, float c) { +// return (float)a * (float)b + c; +// } + 
+// inline float2 fma(bfloat2 a, bfloat2 b, float2 c) { +// return (float2)a * (float2)b + c; +// } + +// inline float4 fma(bfloat4 a, bfloat4 b, float4 c) { +// return (float4)a * (float4)b + c; +// } + +// inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { +// Float8_ res; +// res.x = fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = fma((float4)a.y, (float4)b.y, (float4)c.y); +// return res; +// } +// inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { +// Bfloat8_ res; +// res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x); +// res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y); +// return c; +// } + +// inline void from_float(thread bfloat& dst, float src) { +// dst = static_cast(src); +// } +// inline void from_float(thread bfloat2& dst, float2 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// } +// inline void from_float(thread bfloat4& dst, float4 src) { +// dst.x = static_cast(src.x); +// dst.y = static_cast(src.y); +// dst.z = static_cast(src.z); +// dst.w = static_cast(src.w); +// } +// inline void from_float(thread Bfloat8_& dst, Float8_ src) { +// bfloat4 x; +// bfloat4 y; +// from_float(x, src.x); +// from_float(y, src.y); +// dst.x = x; +// dst.y = y; +// } + +// #else + +struct Bfloat2_ { + bfloat16_t x; + bfloat16_t y; +}; + +struct Bfloat4_ { + Bfloat2_ x; + Bfloat2_ y; +}; + +struct Bfloat8_ { + Bfloat4_ x; + Bfloat4_ y; +}; + +template <> +struct Vec { + using Type = bfloat16_t; +}; +template <> +struct Vec { + using Type = Bfloat2_; +}; +template <> +struct Vec { + using Type = Bfloat4_; +}; +template <> +struct Vec { + using Type = Bfloat8_; +}; + +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +template <> +inline float mul(bfloat16_t a, bfloat16_t b) { + return (float)a * (float)b; +} +template <> +inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { + return a * b; +} + +template <> +inline float2 mul(Bfloat2_ a, Bfloat2_ b) { + float2 a_f((float)a.x, (float)a.y); + float2 b_f((float)b.x, (float)b.y); + return a_f * b_f; +} +template <> +inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { + Bfloat2_ c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline float4 mul(Bfloat4_ a, Bfloat4_ b) { + float2 x = mul(a.x, b.x); + float2 y = mul(a.y, b.y); + float4 c; + c.x = x.x; + c.y = x.y; + c.z = y.x; + c.w = y.y; + return c; +} +template <> +inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { + Bfloat4_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { + Float8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} +template <> +inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { + Bfloat8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline float sum(bfloat16_t a) { + return (float)a; +} + +template <> +inline float sum(Bfloat2_ a) { + return (float)a.x + (float)a.y; +} + +template <> +inline float sum(Bfloat4_ a) { + return sum(a.x) + sum(a.y); +} + +template <> +inline float sum(Bfloat8_ a) { + return sum(a.x) + sum(a.y); +} + +inline float fma(bfloat16_t a, bfloat16_t b, float c) { + return (float)a * (float)b + c; +} +inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) { + return a * b + c; +} + +inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) { + float2 a_f((float)a.x, (float)a.y); 
+ float2 b_f((float)b.x, (float)b.y); + return a_f * b_f + c; +} +inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) { + Bfloat2_ res; + res.x = a.x * b.x + c.x; + res.y = a.y * b.y + c.y; + return res; +} + +inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) { + float4 res; + res.x = fma(a.x.x, b.x.x, c.x); + res.y = fma(a.x.y, b.x.y, c.y); + res.z = fma(a.y.x, b.y.x, c.z); + res.w = fma(a.y.y, b.y.y, c.w); + return res; +} +inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) { + Bfloat4_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { + Bfloat8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread bfloat16_t& dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread Bfloat2_& dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread Bfloat4_& dst, float4 src) { + dst.x.x = static_cast(src.x); + dst.x.y = static_cast(src.y); + dst.y.x = static_cast(src.z); + dst.y.y = static_cast(src.w); +} +inline void from_float(thread Bfloat8_& dst, Float8_ src) { + Bfloat4_ x; + Bfloat4_ y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// #endif + +// FP16 vector data types. +struct Half8_ { + half4 x; + half4 y; +}; + +template <> +struct Vec { + using Type = half; +}; +template <> +struct Vec { + using Type = half2; +}; +template <> +struct Vec { + using Type = half4; +}; +template <> +struct Vec { + using Type = Half8_; +}; + +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +template <> +inline float mul(half a, half b) { + return (float)a * (float)b; +} +template <> +inline half mul(half a, half b) { + return a * b; +} + +template <> +inline float2 mul(half2 a, half2 b) { + return (float2)a * (float2)b; +} +template <> +inline half2 mul(half2 a, half2 b) { + return a * b; +} + +template <> +inline float4 mul(half4 a, half4 b) { + return (float4)a * (float4)b; +} +template <> +inline half4 mul(half4 a, half4 b) { + return a * b; +} + +template <> +inline Float8_ mul(Half8_ a, Half8_ b) { + float4 x = mul(a.x, b.x); + float4 y = mul(a.y, b.y); + Float8_ c; + c.x = x; + c.y = y; + return c; +} +template <> +inline Half8_ mul(Half8_ a, Half8_ b) { + Half8_ c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline float sum(half a) { + return (float)a; +} + +template <> +inline float sum(half2 a) { + return (float)a.x + (float)a.y; +} + +template <> +inline float sum(half4 a) { + return a.x + a.y + a.z + a.w; +} + +template <> +inline float sum(Half8_ a) { + return sum(a.x) + sum(a.y); +} + +inline float fma(half a, half b, float c) { + return (float)a * (float)b + c; +} + +inline float2 fma(half2 a, half2 b, float2 c) { + return (float2)a * (float2)b + c; +} + +inline float4 fma(half4 a, half4 b, float4 c) { + return (float4)a * (float4)b + c; +} + +inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) { + float4 x = fma(a.x, b.x, c.x); + float4 y = fma(a.y, b.y, c.y); + Float8_ res; + res.x = x; + res.y = y; + return res; +} +inline Half8_ 
fma(Half8_ a, Half8_ b, Half8_ c) { + Half8_ res; + res.x = fma(a.x, b.x, c.x); + res.y = fma(a.y, b.y, c.y); + return res; +} + +inline void from_float(thread half& dst, float src) { + dst = static_cast(src); +} +inline void from_float(thread half2& dst, float2 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); +} +inline void from_float(thread half4& dst, float4 src) { + dst.x = static_cast(src.x); + dst.y = static_cast(src.y); + dst.z = static_cast(src.z); + dst.w = static_cast(src.w); +} +inline void from_float(thread Half8_& dst, Float8_ src) { + half4 x; + half4 y; + from_float(x, src.x); + from_float(y, src.y); + dst.x = x; + dst.y = y; +} + +// ========================================== Dot product utilities + +// TODO(EricLBuehler): optimize with vectorization +template +inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { + // Compute the parallel products for Q*K^T (treat vector lanes separately). + using A_vec = typename FloatVec::Type; + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += simd_shuffle_xor(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline float dot( + const threadgroup Vec (&q)[N], + const thread Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +// ========================================== Block sum utility + +// Utility function for attention softmax. +template +inline float block_sum( + threadgroup float* red_smem, + float sum, + uint simd_tid, + uint simd_lid) { + // Compute the sum per simdgroup. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Simd leaders store the data to shared memory. + if (simd_lid == 0) { + red_smem[simd_tid] = sum; + } + + // Make sure the data is in shared memory. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The warps compute the final sums. + if (simd_lid < NUM_WARPS) { + sum = red_smem[simd_lid]; + } + + // Parallel reduction inside the simd group. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += simd_shuffle_xor(sum, mask); + } + + // Broadcast to other threads. + return simd_shuffle(sum, 0); +} + +// ========================================== Paged Attention kernel + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +constant bool use_partitioning [[function_constant(10)]]; +constant bool use_alibi [[function_constant(20)]]; + +template < + typename T, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int NUM_SIMD_LANES, + int PARTITION_SIZE = 0> +[[kernel]] void paged_attention( + device float* exp_sums + [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device float* max_logits + [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device T* out + [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] + device const T* q [[buffer(3)]], // [num_seqs, num_heads, head_size] + device const T* k_cache + [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] + device const T* v_cache + [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] + const constant int& num_kv_heads [[buffer(6)]], // [num_heads] + const constant float& scale [[buffer(7)]], + const constant float& softcapping [[buffer(8)]], + device const uint32_t* block_tables + [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq] + device const uint32_t* context_lens [[buffer(10)]], // [num_seqs] + const constant int& max_num_blocks_per_seq [[buffer(11)]], + device const float* alibi_slopes + [[buffer(12), function_constant(use_alibi)]], // [num_heads] + const constant int& q_stride [[buffer(13)]], + const constant int& kv_block_stride [[buffer(14)]], + const constant int& kv_head_stride [[buffer(15)]], + threadgroup char* shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int seq_idx = threadgroup_position_in_grid.y; + const int partition_idx = threadgroup_position_in_grid.z; + const int max_num_partitions = threadgroups_per_grid.z; + const int thread_idx = thread_position_in_threadgroup.x; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const uint32_t context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
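+  // For example (illustrative values): with BLOCK_SIZE = 16,
+  // PARTITION_SIZE = 512, and context_len = 1300, num_context_blocks = 82
+  // and num_blocks_per_partition = 32, so partition 0 covers blocks
+  // [0, 32), i.e. tokens [0, 512), while the last partition covers blocks
+  // [64, 82) and is clipped to tokens [1024, 1300).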
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES); + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + const int head_idx = threadgroup_position_in_grid.x; + const int num_heads = threadgroups_per_grid.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the query, and the second thread has + // 1, 5, 9, ... th vectors of the query, and so on. + const device T* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Use fp32 on softmax logits for better accuracy + threadgroup float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction + threadgroup float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(T); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const device uint32_t* block_table = + block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. 
+ // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const device T* k_ptr = k_cache + + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * + Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + + // Apply softcapping + if (softcapping != 1.0) { + qk = precise::tanh(qk / softcapping) * softcapping; + } + + // Add the ALiBi bias if slopes are given. + qk += + (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE: It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : max(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = simd_shuffle(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = exp(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum( + &red_smem[NUM_WARPS], exp_sum, simd_tid, simd_lid); + + // Compute softmax. + const float inv_sum = divide(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // If partitioning is enabled, store the max logit and exp_sum. 
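+  // These per-partition statistics are what paged_attention_v2_reduce
+  // consumes to stitch the partition-local softmaxes back into one
+  // sequence-wide softmax.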
+ if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) { + device float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + device float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE: We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + T zero_value = 0; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE: The block number is stored in int32. However, we cast it to int64 + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + Float_L_vec logits_float_vec = *reinterpret_cast( + logits + token_idx - start_token_idx); + from_float(logits_vec, logits_float_vec); + + const device T* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + // NOTE: When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + thread T* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += simd_shuffle_xor(acc, mask); + } + accs[i] = acc; + } + + // NOTE: A barrier is required because the shared memory space for logits + // is reused for the output. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Perform reduction across warps. + threadgroup float* out_smem = + reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. 
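+    // Binary-tree reduction: on each pass the upper half of the active
+    // warps spills its row accumulators to shared memory and the lower
+    // half folds them in, halving the active warp count until warp 0
+    // holds the fully reduced rows.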
+ if (warp_idx >= mid && warp_idx < i) { + threadgroup float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Lower warps update the output. + if (warp_idx < mid) { + const threadgroup float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write the final output. + if (warp_idx == 0) { + device T* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + *(out_ptr + row_idx) = T(accs[i]); + } + } + } +} + +template < + typename T, + int HEAD_SIZE, + int NUM_THREADS, + int NUM_SIMD_LANES, + int PARTITION_SIZE = 0> +[[kernel]] void paged_attention_v2_reduce( + device T* out [[buffer(0)]], + const device float* exp_sums [[buffer(1)]], + const device float* max_logits [[buffer(2)]], + const device T* tmp_out [[buffer(3)]], + device uint32_t* context_lens [[buffer(4)]], + const constant int& max_num_partitions [[buffer(5)]], + threadgroup char* shared_mem [[threadgroup(0)]], + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint3 threadgroups_per_grid [[threadgroups_per_grid]], + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], + uint3 threads_per_threadgroup [[threads_per_threadgroup]], + uint simd_tid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int num_heads = threadgroups_per_grid.x; + const int head_idx = threadgroup_position_in_grid.x; + const int seq_idx = threadgroup_position_in_grid.y; + const uint32_t context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + device T* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const device T* tmp_out_ptr = tmp_out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += threads_per_threadgroup.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES; + const int warp_idx = simd_tid; + const int lane = simd_lid; + + // Workspace for reduction. + threadgroup float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. 
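+  // Combining partitions: with per-partition max m_j and exp-sum s_j, the
+  // global max is m = max_j m_j and the global denominator is
+  // sum_j s_j * exp(m_j - m); each partition's partial output is weighted
+  // by s_j * exp(m_j - m) / denominator below.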
+ threadgroup float* shared_max_logits = + reinterpret_cast(shared_mem); + const device float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = max(max_logit, l); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = simd_shuffle(max_logit, 0); + + // Load rescaled exp sums to shared memory. + threadgroup float* shared_exp_sums = reinterpret_cast( + shared_mem + sizeof(float) * num_partitions); + const device float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + global_exp_sum = block_sum( + &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); + const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const device T* tmp_out_ptr = tmp_out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + device T* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + out_ptr[i] = T(acc); + } +} diff --git a/mlx/backend/metal/kernels/paged_attention.metal b/mlx/backend/metal/kernels/paged_attention.metal new file mode 100644 index 000000000..191e155ab --- /dev/null +++ b/mlx/backend/metal/kernels/paged_attention.metal @@ -0,0 +1,131 @@ +// Copyright © 2023-2024 Apple Inc. 
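+//
+// Explicit instantiations for float, bfloat16_t, and half: paged_attention
+// for head sizes {64, 80, 96, 112, 128, 192, 256} and block sizes
+// {8, 16, 32}, plus paged_attention_v2_reduce for the same head sizes.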
+ +#include "mlx/backend/metal/kernels/paged_attention.h" +#include "mlx/backend/metal/kernels/utils.h" + +#define instantiate_paged_attention_inner( \ + type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \ + template \ + [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention< \ + type, \ + head_size, \ + block_size, \ + num_threads, \ + num_simd_lanes, \ + partition_size>( \ + device float* exp_sums \ + [[buffer(0), function_constant(use_partitioning)]], \ + device float* max_logits \ + [[buffer(1), function_constant(use_partitioning)]], \ + device type* out [[buffer(2)]], \ + device const type* q [[buffer(3)]], \ + device const type* k_cache [[buffer(4)]], \ + device const type* v_cache [[buffer(5)]], \ + const constant int& num_kv_heads [[buffer(6)]], \ + const constant float& scale [[buffer(7)]], \ + const constant float& softcapping [[buffer(8)]], \ + device const uint32_t* block_tables [[buffer(9)]], \ + device const uint32_t* context_lens [[buffer(10)]], \ + const constant int& max_num_blocks_per_seq [[buffer(11)]], \ + device const float* alibi_slopes \ + [[buffer(12), function_constant(use_alibi)]], \ + const constant int& q_stride [[buffer(13)]], \ + const constant int& kv_block_stride [[buffer(14)]], \ + const constant int& kv_head_stride [[buffer(15)]], \ + threadgroup char* shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup \ + [[thread_position_in_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_v2_reduce_inner( \ + type, head_size, num_threads, num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention_v2_reduce< \ + type, \ + head_size, \ + num_threads, \ + num_simd_lanes, \ + partition_size>( \ + device type * out [[buffer(0)]], \ + const device float* exp_sums [[buffer(1)]], \ + const device float* max_logits [[buffer(2)]], \ + const device type* tmp_out [[buffer(3)]], \ + device uint32_t* context_lens [[buffer(4)]], \ + const constant int& max_num_partitions [[buffer(5)]], \ + threadgroup char* shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint3 threads_per_threadgroup [[threads_per_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_heads( \ + type, block_size, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner( \ + type, 64, block_size, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 80, block_size, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 96, block_size, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 112, block_size, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 128, block_size, 
num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 192, block_size, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_inner( \ + type, 256, block_size, num_threads, num_simd_lanes, partition_size); + +#define instantiate_paged_attention_v2_reduce_heads( \ + type, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 64, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 80, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 96, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 112, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 128, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 192, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 256, num_threads, num_simd_lanes, partition_size); + +#define instantiate_paged_attention_block_size( \ + type, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_heads( \ + type, 8, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads( \ + type, 16, num_threads, num_simd_lanes, partition_size); \ + instantiate_paged_attention_heads( \ + type, 32, num_threads, num_simd_lanes, partition_size); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 0 +#define instantiate_paged_attention_v1(type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0); + +// TODO: tune num_threads = 256 +// NOTE: partition_size = 512 +#define instantiate_paged_attention_v2(type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \ + instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512); + +instantiate_paged_attention_v1(float, 32); +instantiate_paged_attention_v1(bfloat16_t, 32); +instantiate_paged_attention_v1(half, 32); + +instantiate_paged_attention_v2(float, 32); +instantiate_paged_attention_v2(bfloat16_t, 32); +instantiate_paged_attention_v2(half, 32); diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index b0375e37f..5e203307b 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( return d.get_kernel(kernel_name, hash_name, func_consts); } +MTL::ComputePipelineState* get_paged_attention_kernel( + metal::Device& d, + const std::string& kernel_name, + const std::string& hash_name, + const metal::MTLFCList& func_consts, + const std::string&) { + return d.get_kernel(kernel_name, "mlx", hash_name, func_consts); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/paged_attention.cpp b/mlx/backend/metal/paged_attention.cpp new file mode 100644 index 000000000..95cae1222 --- /dev/null +++ b/mlx/backend/metal/paged_attention.cpp @@ -0,0 +1,324 @@ +// Copyright © 2023-2024 Apple Inc. 
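+//
+// Host-side dispatch for the paged attention kernels: run_paged_attention
+// assembles the kernel name from dtype, head size, block size, thread
+// count, SIMD width, and partition size, sets the use_partitioning /
+// use_alibi function constants, and encodes either the single-pass (v1)
+// launch or the partitioned (v2) launch.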
+ +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/paged_attention_primitives.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core::paged_attention { + +static void run_paged_attention( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + const int head_size, + const int block_size, + const int num_kv_heads, + const float scale, + const float softcapping, + const int max_context_len, + const int max_num_blocks_per_seq, + const bool use_partitioning, + const std::optional alibi, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const int num_heads, + const int num_seqs, + array& out, + metal::Device& d, + const Stream& s) { + const int partition_size = use_partitioning ? 512 : 0; + const int num_threads = 256; + const int num_simd_lanes = 32; + const bool use_alibi = alibi.has_value(); + + std::string type_string = get_type_string(q.dtype()); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "paged_attention_", + type_string, + "_hs", + head_size, + "_bs", + block_size, + "_nt", + num_threads, + "_nsl", + num_simd_lanes, + "_ps", + partition_size); + + auto template_def = get_template_definition( + kname, + "paged_attention", + type_string, + head_size, + block_size, + num_threads, + num_simd_lanes, + partition_size); + + // Encode and dispatch kernel + metal::MTLFCList func_consts = { + {use_partitioning, MTL::DataType::DataTypeBool, 10}, + {use_alibi, MTL::DataType::DataTypeBool, 20}, + }; + + std::string hash_name = kname; + auto kernel = get_paged_attention_kernel( + d, kname, hash_name, func_consts, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + int local_max_num_partitions = 1; + if (use_partitioning) { + local_max_num_partitions = + (max_context_len + partition_size - 1) / partition_size; + } + + int logits_size = use_partitioning ? partition_size * size_of(float32) : 0; + int outputs_size = use_partitioning + ? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32) + : 0; + int shared_mem_size = + use_partitioning ? 
std::max(logits_size, outputs_size) : 0; + if (use_partitioning) { + compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0); + } + + if (use_partitioning) { + auto tmp_out = array( + {num_seqs, num_heads, local_max_num_partitions, head_size}, float32); + tmp_out.set_data(allocator::malloc(tmp_out.nbytes())); + auto exp_sums = + array({num_seqs, num_heads, local_max_num_partitions}, float32); + exp_sums.set_data(allocator::malloc(exp_sums.nbytes())); + + std::vector temporaries = {tmp_out, exp_sums}; + + compute_encoder.set_output_array(tmp_out, 0); + compute_encoder.set_output_array(exp_sums, 1); + compute_encoder.set_output_array(out, 2); + compute_encoder.set_input_array(q, 3); + compute_encoder.set_input_array(k_cache, 4); + compute_encoder.set_input_array(v_cache, 5); + + compute_encoder.set_bytes(num_kv_heads, 6); + compute_encoder.set_bytes(scale, 7); + compute_encoder.set_bytes(softcapping, 8); + + compute_encoder.set_input_array(block_tables, 9); + compute_encoder.set_input_array(context_lens, 10); + + compute_encoder.set_bytes(max_num_blocks_per_seq, 11); + + if (use_alibi) { + compute_encoder.set_input_array(alibi.value(), 12); + } + + compute_encoder.set_bytes(q_stride, 13); + compute_encoder.set_bytes(kv_block_stride, 14); + compute_encoder.set_bytes(kv_head_stride, 15); + + MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions); + MTL::Size group_dims(num_threads, 1, 1); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + d.add_temporaries(std::move(temporaries), s.index); + } else { + compute_encoder.set_output_array(out, 2); + compute_encoder.set_input_array(q, 3); + compute_encoder.set_input_array(k_cache, 4); + compute_encoder.set_input_array(v_cache, 5); + + compute_encoder.set_bytes(num_kv_heads, 6); + compute_encoder.set_bytes(scale, 7); + compute_encoder.set_bytes(softcapping, 8); + + compute_encoder.set_input_array(block_tables, 9); + compute_encoder.set_input_array(context_lens, 10); + + compute_encoder.set_bytes(max_num_blocks_per_seq, 11); + + if (use_alibi) { + compute_encoder.set_input_array(alibi.value(), 12); + } + + compute_encoder.set_bytes(q_stride, 13); + compute_encoder.set_bytes(kv_block_stride, 14); + compute_encoder.set_bytes(kv_head_stride, 15); + + MTL::Size grid_dims(num_heads, num_seqs, 1); + MTL::Size group_dims(num_threads, 1, 1); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + } +} + +void paged_attention_v1( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + const int head_size, + const int block_size, + const int num_kv_heads, + const float scale, + const float softcapping, + const int max_context_len, + const int max_num_blocks_per_seq, + const std::optional alibi, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const int num_heads, + const int num_seqs, + array& out, + metal::Device& d, + const Stream& s) { + run_paged_attention( + q, + k_cache, + v_cache, + block_tables, + context_lens, + head_size, + block_size, + num_kv_heads, + scale, + softcapping, + max_context_len, + max_num_blocks_per_seq, + /*use_partitioning=*/false, + alibi, + q_stride, + kv_block_stride, + kv_head_stride, + num_heads, + num_seqs, + out, + d, + s); +} + +void paged_attention_v2( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + const int head_size, + const int block_size, + const int num_kv_heads, + const float scale, + const 
float softcapping, + const int max_context_len, + const int max_num_blocks_per_seq, + const int /* max_num_partitions */, + const std::optional alibi, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const int num_heads, + const int num_seqs, + array& out, + metal::Device& d, + const Stream& s) { + run_paged_attention( + q, + k_cache, + v_cache, + block_tables, + context_lens, + head_size, + block_size, + num_kv_heads, + scale, + softcapping, + max_context_len, + max_num_blocks_per_seq, + /*use_partitioning=*/true, + alibi, + q_stride, + kv_block_stride, + kv_head_stride, + num_heads, + num_seqs, + out, + d, + s); +} + +void PagedAttention::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = metal::device(s.device); + + out.set_data(allocator::malloc(out.nbytes())); + + auto& q = inputs[0]; + auto& k_cache = inputs[1]; + auto& v_cache = inputs[2]; + auto& block_tables = inputs[3]; + auto& context_lens = inputs[4]; + const auto alibi_slopes = + inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt; + + if (use_v1_) { + paged_attention_v1( + q, + k_cache, + v_cache, + block_tables, + context_lens, + head_size_, + block_size_, + num_kv_heads_, + softmax_scale_, + softcapping_.value_or(1.), + max_context_len_, + max_num_blocks_per_seq_, + alibi_slopes, + q_stride_, + kv_block_stride_, + kv_head_stride_, + num_heads_, + num_seqs_, + out, + d, + s); + } else { + paged_attention_v2( + q, + k_cache, + v_cache, + block_tables, + context_lens, + head_size_, + block_size_, + num_kv_heads_, + softmax_scale_, + softcapping_.value_or(1.), + max_context_len_, + max_num_blocks_per_seq_, + max_num_partitions_, + alibi_slopes, + q_stride_, + kv_block_stride_, + kv_head_stride_, + num_heads_, + num_seqs_, + out, + d, + s); + } +} +} // namespace mlx::core::paged_attention \ No newline at end of file diff --git a/mlx/mlx.h b/mlx/mlx.h index de3ee392a..c44b055b3 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -17,6 +17,7 @@ #include "mlx/linalg.h" #include "mlx/memory.h" #include "mlx/ops.h" +#include "mlx/paged_attention.h" #include "mlx/random.h" #include "mlx/stream.h" #include "mlx/transforms.h" diff --git a/mlx/paged_attention.cpp b/mlx/paged_attention.cpp new file mode 100644 index 000000000..037eb23ac --- /dev/null +++ b/mlx/paged_attention.cpp @@ -0,0 +1,170 @@ +// Copyright © 2023-2024 Apple Inc. + +// Required for using M_PI in MSVC. 
+#define _USE_MATH_DEFINES + +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/paged_attention_primitives.h" +#include "mlx/utils.h" + +namespace mlx::core::paged_attention { + +array paged_attention( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + int max_context_len, + float softmax_scale, + std::optional alibi_slopes = std::nullopt, + std::optional softcapping = std::nullopt, + StreamOrDevice s_ = {}) { + auto s = to_stream(s_); + + // supported dtypes + if (!issubdtype(q.dtype(), floating)) { + throw std::invalid_argument( + "[paged_attention] Only real floating types are supported"); + } + if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) { + throw std::invalid_argument( + "[paged_attention] q/k_cache/v_cache dtype must match"); + } + if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) { + throw std::invalid_argument( + "[paged_attention] block_tables/context_lens dtype must be uint32"); + } + + // rank checks + if (q.ndim() != 3) + throw std::invalid_argument("[paged_attention] `q` must be rank-3"); + if (k_cache.ndim() != 5) + throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5"); + if (v_cache.ndim() != 4) + throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4"); + if (block_tables.ndim() != 2) + throw std::invalid_argument( + "[paged_attention] `block_tables` must be rank-2"); + if (context_lens.ndim() != 1) + throw std::invalid_argument( + "[paged_attention] `context_lens` must be rank-1"); + + // 4. Shape consistency + const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size] + const auto& kc_shape = k_cache.shape(); + const auto& vc_shape = v_cache.shape(); + const auto& bt_shape = block_tables.shape(); + const auto& cl_shape = context_lens.shape(); + + int num_seqs = q_shape[0]; + int num_heads = q_shape[1]; + int head_size = q_shape[2]; + + // Allowed head sizes + switch (head_size) { + case 64: + case 80: + case 96: + case 112: + case 128: + case 192: + case 256: + break; + default: + throw std::invalid_argument( + "[paged_attention] `head_size` must be one of " + "{64, 80, 96, 112, 128, 192, 256}"); + } + + int max_num_blocks_per_seq = bt_shape[1]; + + // block_tables first dimension must match num_seqs + if (bt_shape[0] != num_seqs) { + std::stringstream ss; + ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0] + << ") must equal q.shape[0] (" << num_seqs << ")"; + throw std::invalid_argument(ss.str()); + } + + // Extract k_cache dimensions + int num_blocks = kc_shape[0]; + int num_kv_heads = kc_shape[1]; + int head_size_kc = kc_shape[2]; + int block_size = kc_shape[3]; + int x = kc_shape[4]; + + if (head_size_kc * x != head_size) { + std::stringstream ss; + ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x + << ") must equal q head_size (" << head_size << ")"; + throw std::invalid_argument(ss.str()); + } + + // v_cache must match the derived dimensions + if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads && + vc_shape[2] == head_size && vc_shape[3] == block_size)) { + throw std::invalid_argument( + "[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`"); + } + + // context_lens length must match num_seqs + if (cl_shape[0] != num_seqs) { + std::stringstream ss; + ss << "paged_attention: context_lens length (" << cl_shape[0] + << ") must equal q.shape[0] (" << num_seqs << ")"; + throw 
std::invalid_argument(ss.str()); + } + + constexpr int partition_size = 512; + int max_num_partitions = + (max_context_len + partition_size - 1) / partition_size; // ceil‑div + bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) && + (partition_size % block_size == 0); + + auto out_shape = q.shape(); + + auto inputs = std::vector{ + std::move(q), + std::move(k_cache), + std::move(v_cache), + std::move(block_tables), + std::move(context_lens)}; + if (alibi_slopes.has_value()) { + inputs.push_back(std::move(alibi_slopes.value())); + } + + int q_stride = q.strides()[0]; + int kv_block_stride = k_cache.strides()[0]; + int kv_head_stride = k_cache.strides()[1]; + + return array( + std::move(out_shape), + q.dtype(), + std::make_shared( + to_stream(s), + use_v1, + max_context_len, + head_size, + block_size, + num_kv_heads, + softmax_scale, + max_num_blocks_per_seq, + max_num_partitions, + q_stride, + kv_block_stride, + kv_head_stride, + num_heads, + num_seqs, + softcapping), + inputs); +} + +} // namespace mlx::core::paged_attention diff --git a/mlx/paged_attention.h b/mlx/paged_attention.h new file mode 100644 index 000000000..01999b2ea --- /dev/null +++ b/mlx/paged_attention.h @@ -0,0 +1,34 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core::paged_attention { + +/** + * \defgroup ops Paged attention operations + * @{ + */ + +/** PagedAttention operation. */ +array paged_attention( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + int max_context_len, + float softmax_scale, + std::optional alibi_slopes = std::nullopt, + std::optional softcapping = std::nullopt, + StreamOrDevice s_ = {}); + +/** @} */ + +} // namespace mlx::core::paged_attention diff --git a/mlx/paged_attention_primitives.h b/mlx/paged_attention_primitives.h new file mode 100644 index 000000000..353d21a5c --- /dev/null +++ b/mlx/paged_attention_primitives.h @@ -0,0 +1,82 @@ +// Copyright © 2023-2024 Apple Inc. + +// Required for using M_PI in MSVC. 
+#define _USE_MATH_DEFINES
+
+#include <optional>
+
+#include "mlx/primitives.h"
+
+namespace mlx::core::paged_attention {
+
+class PagedAttention : public UnaryPrimitive {
+ public:
+  explicit PagedAttention(
+      Stream stream,
+      bool use_v1,
+      int max_context_len,
+      int head_size,
+      int block_size,
+      int num_kv_heads,
+      float softmax_scale,
+      int max_num_blocks_per_seq,
+      int max_num_partitions,
+      int q_stride,
+      int kv_block_stride,
+      int kv_head_stride,
+      int num_heads,
+      int num_seqs,
+      std::optional<float> softcapping = std::nullopt)
+      : UnaryPrimitive(stream),
+        use_v1_(use_v1),
+        max_context_len_(max_context_len),
+        head_size_(head_size),
+        block_size_(block_size),
+        num_kv_heads_(num_kv_heads),
+        max_num_blocks_per_seq_(max_num_blocks_per_seq),
+        max_num_partitions_(max_num_partitions),
+        q_stride_(q_stride),
+        kv_block_stride_(kv_block_stride),
+        kv_head_stride_(kv_head_stride),
+        num_heads_(num_heads),
+        num_seqs_(num_seqs),
+        softmax_scale_(softmax_scale),
+        softcapping_(softcapping) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
+    throw std::runtime_error("NYI");
+  }
+
+  void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
+
+  DEFINE_PRINT(PagedAttention);
+
+  bool is_equivalent(const Primitive& other) const override;
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  auto state() const {
+    return std::make_tuple(
+        max_context_len_,
+        head_size_,
+        block_size_,
+        softmax_scale_,
+        softcapping_);
+  }
+
+ private:
+  bool use_v1_;
+  int max_context_len_;
+  int head_size_;
+  int block_size_;
+  int num_kv_heads_;
+  int max_num_blocks_per_seq_;
+  int max_num_partitions_;
+  int q_stride_;
+  int kv_block_stride_;
+  int kv_head_stride_;
+  int num_heads_;
+  int num_seqs_;
+  float softmax_scale_;
+  std::optional<float> softcapping_ = std::nullopt;
+};
+
+} // namespace mlx::core::paged_attention
\ No newline at end of file
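
For context, a minimal sketch of how the new op can be driven end to end from the C++ API, assuming the input layout documented in mlx/paged_attention.cpp (q: [num_seqs, num_heads, head_size], k_cache: [num_blocks, num_kv_heads, head_size/x, block_size, x] with x = 16 / sizeof(T), v_cache: [num_blocks, num_kv_heads, head_size, block_size]). The sizes, random fills, and the trivial block table are illustrative only and not part of this change:

#include <cmath>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Illustrative sizes: head_size must be one of {64, 80, 96, 112, 128,
  // 192, 256} and block_size one of the instantiated {8, 16, 32}.
  int num_seqs = 2, num_heads = 8, num_kv_heads = 8, head_size = 128;
  int block_size = 16, num_blocks = 64, max_context_len = 256;
  int x = 16 / sizeof(float16_t); // 16-byte packing along the last k_cache axis
  int max_num_blocks_per_seq = max_context_len / block_size;

  auto q = astype(random::normal({num_seqs, num_heads, head_size}), float16);
  auto k_cache = astype(
      random::normal({num_blocks, num_kv_heads, head_size / x, block_size, x}),
      float16);
  auto v_cache = astype(
      random::normal({num_blocks, num_kv_heads, head_size, block_size}),
      float16);
  // Every sequence reads from physical block 0 here; a real cache manager
  // would fill the table with the blocks it actually allocated.
  auto block_tables = zeros({num_seqs, max_num_blocks_per_seq}, uint32);
  auto context_lens = full({num_seqs}, max_context_len, uint32);

  auto out = paged_attention::paged_attention(
      q,
      k_cache,
      v_cache,
      block_tables,
      context_lens,
      max_context_len,
      /*softmax_scale=*/1.0f / std::sqrt(static_cast<float>(head_size)));
  eval(out); // out has shape [num_seqs, num_heads, head_size]
}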