From 0b5c5680f4f000df3ffffcaa5e1ecc7ae789808e Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 31 May 2025 09:16:14 -0400
Subject: [PATCH] Add v1 call

---
 mlx/backend/metal/kernels.h                 |   7 +
 mlx/backend/metal/kernels/paged_attention.h |   1 +
 mlx/backend/metal/nojit_kernels.cpp         |   9 ++
 mlx/backend/metal/paged_attention.cpp       | 149 +++++++++++++++++++-
 mlx/paged_attention.cpp                     |  41 ++++--
 mlx/paged_attention_primitives.h            |  34 ++++-
 6 files changed, 225 insertions(+), 16 deletions(-)

diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h
index 6d8864385..db438159d 100644
--- a/mlx/backend/metal/kernels.h
+++ b/mlx/backend/metal/kernels.h
@@ -239,6 +239,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     int wn,
     bool transpose);
 
+MTL::ComputePipelineState* get_paged_attention_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const std::string&);
+
 // Create a GPU kernel template definition for JIT compilation
 template <typename... Args>
 std::string
diff --git a/mlx/backend/metal/kernels/paged_attention.h b/mlx/backend/metal/kernels/paged_attention.h
index 257d70eb2..2e11ddc64 100644
--- a/mlx/backend/metal/kernels/paged_attention.h
+++ b/mlx/backend/metal/kernels/paged_attention.h
@@ -2,6 +2,7 @@
 
 #include
 #include
+#include "mlx/backend/metal/kernels/utils.h"
 
 using namespace metal;
 
diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp
index 8da147971..3062385ab 100644
--- a/mlx/backend/metal/nojit_kernels.cpp
+++ b/mlx/backend/metal/nojit_kernels.cpp
@@ -286,4 +286,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
 }
 
+MTL::ComputePipelineState* get_paged_attention_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const std::string&) {
+  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/paged_attention.cpp b/mlx/backend/metal/paged_attention.cpp
index bc4da8788..a1b75f8cc 100644
--- a/mlx/backend/metal/paged_attention.cpp
+++ b/mlx/backend/metal/paged_attention.cpp
@@ -1,11 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
-#include "mlx/backend/common/broadcasting.h" #include "mlx/backend/common/compiled.h" -#include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" -#include "mlx/backend/metal/reduce.h" #include "mlx/backend/metal/utils.h" #include "mlx/paged_attention_primitives.h" #include "mlx/primitives.h" @@ -13,6 +10,125 @@ namespace mlx::core::paged_attention { +void paged_attention_v1( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + const int head_size, + const int block_size, + const int num_kv_heads, + const float scale, + const float softcapping, + const int max_context_len, + const int max_num_blocks_per_seq, + const std::optional alibi, + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + const int num_heads, + const int num_seqs, + array& out, + metal::Device& d, + const Stream& s) { + const int partition_size = 0; + const int num_threads = 256; + const int num_simd_lanes = 32; + const bool use_partitioning = false; + const bool use_alibi = alibi.has_value(); + + std::string type_string = get_type_string(q.dtype()); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "paged_attention_", + type_string, + "_hs", + head_size, + "_bs", + block_size, + "_nt", + num_threads, + "_nsl", + num_simd_lanes, + "_ps", + partition_size); + + auto template_def = get_template_definition( + kname, + "paged_attention", + type_string, + head_size, + block_size, + num_threads, + num_simd_lanes, + partition_size); + + // Encode and dispatch kernel + metal::MTLFCList func_consts = { + {use_partitioning, MTL::DataType::DataTypeBool, 10}, + {use_alibi, MTL::DataType::DataTypeBool, 20}, + }; + + std::string hash_name = kname; + auto kernel = get_paged_attention_kernel( + d, kname, hash_name, func_consts, template_def); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + auto num_simds = num_threads / num_simd_lanes; + auto max_num_partitions = + (max_context_len + partition_size - 1) / partition_size; + auto logits_size = partition_size * size_of(float32); + auto outputs_size = (num_simds / 2) * head_size * size_of(float32); + auto shared_mem_size = std::max(logits_size, outputs_size); + compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0); + + compute_encoder.set_output_array(out, 2); + compute_encoder.set_input_array(q, 3); + compute_encoder.set_input_array(k_cache, 4); + compute_encoder.set_input_array(v_cache, 5); + + compute_encoder.set_bytes(num_kv_heads, 6); + compute_encoder.set_bytes(scale, 7); + compute_encoder.set_bytes(softcapping, 8); + + compute_encoder.set_input_array(block_tables, 9); + compute_encoder.set_input_array(context_lens, 10); + + compute_encoder.set_bytes(max_num_blocks_per_seq, 11); + + if (use_alibi) { + compute_encoder.set_input_array(alibi.value(), 12); + } + + compute_encoder.set_bytes(q_stride, 13); + compute_encoder.set_bytes(kv_block_stride, 14); + compute_encoder.set_bytes(kv_head_stride, 15); + + MTL::Size grid_dims(num_heads, num_seqs, 1); + MTL::Size group_dims(num_threads, 1, 1); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + return; +} + +void paged_attention_v2( + const array& q, + const array& k_cache, + const array& v_cache, + const array& block_tables, + const array& context_lens, + const int head_size, + const int block_size, + array& out, + metal::Device& d, + const Stream& s) { + throw 
std::runtime_error("NYI"); +} + void PagedAttention::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); @@ -26,6 +142,31 @@ void PagedAttention::eval_gpu(const std::vector& inputs, array& out) { auto& context_lens = inputs[4]; const auto alibi_slopes = inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt; - return; + + if (use_v1_) { + paged_attention_v1( + q, + k_cache, + v_cache, + block_tables, + context_lens, + head_size_, + block_size_, + num_kv_heads_, + softmax_scale_, + softcapping_.value_or(0.), + max_context_len_, + max_num_blocks_per_seq_, + alibi_slopes, + q_stride_, + kv_block_stride_, + kv_head_stride_, + num_heads_, + num_seqs_, + out, + d, + s); + } else { + } } } // namespace mlx::core::paged_attention \ No newline at end of file diff --git a/mlx/paged_attention.cpp b/mlx/paged_attention.cpp index 54ab30f05..8300d9250 100644 --- a/mlx/paged_attention.cpp +++ b/mlx/paged_attention.cpp @@ -64,9 +64,9 @@ array paged_attention( const auto& bt_shape = block_tables.shape(); const auto& cl_shape = context_lens.shape(); - int64_t num_seqs = q_shape[0]; - int64_t num_heads = q_shape[1]; - int64_t head_size = q_shape[2]; + int num_seqs = q_shape[0]; + int num_heads = q_shape[1]; + int head_size = q_shape[2]; // Allowed head sizes switch (head_size) { @@ -84,6 +84,8 @@ array paged_attention( "{64, 80, 96, 112, 128, 192, 256}"); } + int max_num_blocks_per_seq = bt_shape[1]; + // block_tables first dimension must match num_seqs if (bt_shape[0] != num_seqs) { std::stringstream ss; @@ -93,11 +95,11 @@ array paged_attention( } // Extract k_cache dimensions - int64_t num_blocks = kc_shape[0]; - int64_t num_kv_heads = kc_shape[1]; - int64_t head_size_kc = kc_shape[2]; - int64_t block_size = kc_shape[3]; - int64_t x = kc_shape[4]; + int num_blocks = kc_shape[0]; + int num_kv_heads = kc_shape[1]; + int head_size_kc = kc_shape[2]; + int block_size = kc_shape[3]; + int x = kc_shape[4]; if (head_size_kc * x != head_size) { std::stringstream ss; @@ -121,8 +123,8 @@ array paged_attention( throw std::invalid_argument(ss.str()); } - constexpr int64_t partition_size = 512; - int64_t max_num_partitions = + constexpr int partition_size = 512; + int max_num_partitions = (max_context_len + partition_size - 1) / partition_size; // ceil‑div bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) && (partition_size % block_size == 0); @@ -139,11 +141,28 @@ array paged_attention( inputs.push_back(std::move(alibi_slopes.value())); } + int q_stride = q.strides()[0]; + int kv_block_stride = k_cache.strides()[0]; + int kv_head_stride = k_cache.strides()[1]; + return array( std::move(out_shape), q.dtype(), std::make_shared( - to_stream(s), use_v1, max_context_len, softmax_scale, softcapping), + to_stream(s), + use_v1, + max_context_len, + head_size, + block_size, + num_kv_heads, + softmax_scale, + max_num_blocks_per_seq, + q_stride, + kv_block_stride, + kv_head_stride, + num_heads, + num_seqs, + softcapping), inputs); } diff --git a/mlx/paged_attention_primitives.h b/mlx/paged_attention_primitives.h index 675a61af2..886a161cb 100644 --- a/mlx/paged_attention_primitives.h +++ b/mlx/paged_attention_primitives.h @@ -15,11 +15,29 @@ class PagedAttention : public UnaryPrimitive { Stream stream, bool use_v1, int max_context_len, + int head_size, + int block_size, + int num_kv_heads, + int max_num_blocks_per_seq, + int q_stride, + int kv_block_stride, + int kv_head_stride, + int num_heads, + int num_seqs, float softmax_scale, std::optional 
diff --git a/mlx/paged_attention_primitives.h b/mlx/paged_attention_primitives.h
index 675a61af2..886a161cb 100644
--- a/mlx/paged_attention_primitives.h
+++ b/mlx/paged_attention_primitives.h
@@ -15,11 +15,29 @@ class PagedAttention : public UnaryPrimitive {
       Stream stream,
       bool use_v1,
       int max_context_len,
+      int head_size,
+      int block_size,
+      int num_kv_heads,
+      int max_num_blocks_per_seq,
+      int q_stride,
+      int kv_block_stride,
+      int kv_head_stride,
+      int num_heads,
+      int num_seqs,
       float softmax_scale,
       std::optional<float> softcapping = std::nullopt)
       : UnaryPrimitive(stream),
         use_v1_(use_v1),
         max_context_len_(max_context_len),
+        head_size_(head_size),
+        block_size_(block_size),
+        num_kv_heads_(num_kv_heads),
+        max_num_blocks_per_seq_(max_num_blocks_per_seq),
+        q_stride_(q_stride),
+        kv_block_stride_(kv_block_stride),
+        kv_head_stride_(kv_head_stride),
+        num_heads_(num_heads),
+        num_seqs_(num_seqs),
         softmax_scale_(softmax_scale),
         softcapping_(softcapping) {}
 
@@ -34,12 +52,26 @@ class PagedAttention : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   auto state() const {
-    return std::make_tuple(max_context_len_, softmax_scale_, softcapping_);
+    return std::make_tuple(
+        max_context_len_,
+        head_size_,
+        block_size_,
+        softmax_scale_,
+        softcapping_);
   }
 
  private:
   bool use_v1_;
   int max_context_len_;
+  int head_size_;
+  int block_size_;
+  int num_kv_heads_;
+  int max_num_blocks_per_seq_;
+  int q_stride_;
+  int kv_block_stride_;
+  int kv_head_stride_;
+  int num_heads_;
+  int num_seqs_;
   float softmax_scale_;
   std::optional<float> softcapping_ = std::nullopt;
 };
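
One follow-up on the primitive above: state() exposes only max_context_len_, head_size_, block_size_, softmax_scale_, and softcapping_, while use_v1_ and the new stride/shape members are left out. If state() is meant to capture everything that distinguishes two PagedAttention primitives (for example, for is_equivalent or graph export), a fuller tuple may be wanted. A sketch of such a body, using only members already declared in this header:

  auto state() const {
    return std::make_tuple(
        use_v1_,
        max_context_len_,
        head_size_,
        block_size_,
        num_kv_heads_,
        max_num_blocks_per_seq_,
        q_stride_,
        kv_block_stride_,
        kv_head_stride_,
        num_heads_,
        num_seqs_,
        softmax_scale_,
        softcapping_);
  }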