From 4d68bd3250baad627105b56435ef75d36c105d3b Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Sat, 31 May 2025 09:48:24 -0400
Subject: [PATCH] Refactor v1/v2 caller code

---
 mlx/backend/metal/paged_attention.cpp | 371 ++++++++++++++------------
 1 file changed, 198 insertions(+), 173 deletions(-)

diff --git a/mlx/backend/metal/paged_attention.cpp b/mlx/backend/metal/paged_attention.cpp
index d5cd0b7b2..95cae1222 100644
--- a/mlx/backend/metal/paged_attention.cpp
+++ b/mlx/backend/metal/paged_attention.cpp
@@ -10,6 +10,159 @@ namespace mlx::core::paged_attention {
 
+static void run_paged_attention(
+    const array& q,
+    const array& k_cache,
+    const array& v_cache,
+    const array& block_tables,
+    const array& context_lens,
+    const int head_size,
+    const int block_size,
+    const int num_kv_heads,
+    const float scale,
+    const float softcapping,
+    const int max_context_len,
+    const int max_num_blocks_per_seq,
+    const bool use_partitioning,
+    const std::optional<array> alibi,
+    const int q_stride,
+    const int kv_block_stride,
+    const int kv_head_stride,
+    const int num_heads,
+    const int num_seqs,
+    array& out,
+    metal::Device& d,
+    const Stream& s) {
+  const int partition_size = use_partitioning ? 512 : 0;
+  const int num_threads = 256;
+  const int num_simd_lanes = 32;
+  const bool use_alibi = alibi.has_value();
+
+  std::string type_string = get_type_string(q.dtype());
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "paged_attention_",
+      type_string,
+      "_hs",
+      head_size,
+      "_bs",
+      block_size,
+      "_nt",
+      num_threads,
+      "_nsl",
+      num_simd_lanes,
+      "_ps",
+      partition_size);
+
+  auto template_def = get_template_definition(
+      kname,
+      "paged_attention",
+      type_string,
+      head_size,
+      block_size,
+      num_threads,
+      num_simd_lanes,
+      partition_size);
+
+  // Encode and dispatch kernel
+  metal::MTLFCList func_consts = {
+      {use_partitioning, MTL::DataType::DataTypeBool, 10},
+      {use_alibi, MTL::DataType::DataTypeBool, 20},
+  };
+
+  std::string hash_name = kname;
+  auto kernel = get_paged_attention_kernel(
+      d, kname, hash_name, func_consts, template_def);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  int local_max_num_partitions = 1;
+  if (use_partitioning) {
+    local_max_num_partitions =
+        (max_context_len + partition_size - 1) / partition_size;
+  }
+
+  int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
+  int outputs_size = use_partitioning
+      ? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
+      : 0;
+  int shared_mem_size =
+      use_partitioning ? std::max(logits_size, outputs_size) : 0;
+  if (use_partitioning) {
+    compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
+  }
+
+  if (use_partitioning) {
+    auto tmp_out = array(
+        {num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
+    tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
+    auto exp_sums =
+        array({num_seqs, num_heads, local_max_num_partitions}, float32);
+    exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
+
+    std::vector<array> temporaries = {tmp_out, exp_sums};
+
+    compute_encoder.set_output_array(tmp_out, 0);
+    compute_encoder.set_output_array(exp_sums, 1);
+    compute_encoder.set_output_array(out, 2);
+    compute_encoder.set_input_array(q, 3);
+    compute_encoder.set_input_array(k_cache, 4);
+    compute_encoder.set_input_array(v_cache, 5);
+
+    compute_encoder.set_bytes(num_kv_heads, 6);
+    compute_encoder.set_bytes(scale, 7);
+    compute_encoder.set_bytes(softcapping, 8);
+
+    compute_encoder.set_input_array(block_tables, 9);
+    compute_encoder.set_input_array(context_lens, 10);
+
+    compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
+
+    if (use_alibi) {
+      compute_encoder.set_input_array(alibi.value(), 12);
+    }
+
+    compute_encoder.set_bytes(q_stride, 13);
+    compute_encoder.set_bytes(kv_block_stride, 14);
+    compute_encoder.set_bytes(kv_head_stride, 15);
+
+    MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
+    MTL::Size group_dims(num_threads, 1, 1);
+
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+    d.add_temporaries(std::move(temporaries), s.index);
+  } else {
+    compute_encoder.set_output_array(out, 2);
+    compute_encoder.set_input_array(q, 3);
+    compute_encoder.set_input_array(k_cache, 4);
+    compute_encoder.set_input_array(v_cache, 5);
+
+    compute_encoder.set_bytes(num_kv_heads, 6);
+    compute_encoder.set_bytes(scale, 7);
+    compute_encoder.set_bytes(softcapping, 8);
+
+    compute_encoder.set_input_array(block_tables, 9);
+    compute_encoder.set_input_array(context_lens, 10);
+
+    compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
+
+    if (use_alibi) {
+      compute_encoder.set_input_array(alibi.value(), 12);
+    }
+
+    compute_encoder.set_bytes(q_stride, 13);
+    compute_encoder.set_bytes(kv_block_stride, 14);
+    compute_encoder.set_bytes(kv_head_stride, 15);
+
+    MTL::Size grid_dims(num_heads, num_seqs, 1);
+    MTL::Size group_dims(num_threads, 1, 1);
+
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+  }
+}
+
 void paged_attention_v1(
     const array& q,
     const array& k_cache,
@@ -32,87 +185,29 @@ void paged_attention_v1(
     array& out,
     metal::Device& d,
     const Stream& s) {
-  const int partition_size = 0;
-  const int num_threads = 256;
-  const int num_simd_lanes = 32;
-  const bool use_partitioning = false;
-  const bool use_alibi = alibi.has_value();
-
-  std::string type_string = get_type_string(q.dtype());
-  std::string kname;
-  kname.reserve(64);
-  concatenate(
-      kname,
-      "paged_attention_",
-      type_string,
-      "_hs",
-      head_size,
-      "_bs",
-      block_size,
-      "_nt",
-      num_threads,
-      "_nsl",
-      num_simd_lanes,
-      "_ps",
-      partition_size);
-
-  auto template_def = get_template_definition(
-      kname,
-      "paged_attention",
-      type_string,
+  run_paged_attention(
+      q,
+      k_cache,
+      v_cache,
+      block_tables,
+      context_lens,
       head_size,
       block_size,
-      num_threads,
-      num_simd_lanes,
-      partition_size);
-
-  // Encode and dispatch kernel
-  metal::MTLFCList func_consts = {
-      {use_partitioning, MTL::DataType::DataTypeBool, 10},
-      {use_alibi, MTL::DataType::DataTypeBool, 20},
-  };
-
-  std::string hash_name = kname;
-  auto kernel = get_paged_attention_kernel(
-      d, kname, hash_name, func_consts, template_def);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  auto num_simds = num_threads / num_simd_lanes;
-  auto max_num_partitions =
-      (max_context_len + partition_size - 1) / partition_size;
-  auto logits_size = partition_size * size_of(float32);
-  auto outputs_size = (num_simds / 2) * head_size * size_of(float32);
-  auto shared_mem_size = std::max(logits_size, outputs_size);
-  compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
-
-  compute_encoder.set_output_array(out, 2);
-  compute_encoder.set_input_array(q, 3);
-  compute_encoder.set_input_array(k_cache, 4);
-  compute_encoder.set_input_array(v_cache, 5);
-
-  compute_encoder.set_bytes(num_kv_heads, 6);
-  compute_encoder.set_bytes(scale, 7);
-  compute_encoder.set_bytes(softcapping, 8);
-
-  compute_encoder.set_input_array(block_tables, 9);
-  compute_encoder.set_input_array(context_lens, 10);
-
-  compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
-
-  if (use_alibi) {
-    compute_encoder.set_input_array(alibi.value(), 12);
-  }
-
-  compute_encoder.set_bytes(q_stride, 13);
-  compute_encoder.set_bytes(kv_block_stride, 14);
-  compute_encoder.set_bytes(kv_head_stride, 15);
-
-  MTL::Size grid_dims(num_heads, num_seqs, 1);
-  MTL::Size group_dims(num_threads, 1, 1);
-
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-  return;
+      num_kv_heads,
+      scale,
+      softcapping,
+      max_context_len,
+      max_num_blocks_per_seq,
+      /*use_partitioning=*/false,
+      alibi,
+      q_stride,
+      kv_block_stride,
+      kv_head_stride,
+      num_heads,
+      num_seqs,
+      out,
+      d,
+      s);
 }
 
 void paged_attention_v2(
@@ -128,7 +223,7 @@
     const float softcapping,
     const int max_context_len,
     const int max_num_blocks_per_seq,
-    const int max_num_partitions,
+    const int /* max_num_partitions */,
     const std::optional<array> alibi,
     const int q_stride,
     const int kv_block_stride,
@@ -138,99 +233,29 @@
     array& out,
     metal::Device& d,
     const Stream& s) {
-  const int partition_size = 512;
-  const int num_threads = 256;
-  const int num_simd_lanes = 32;
-  const bool use_partitioning = true;
-  const bool use_alibi = alibi.has_value();
-
-  std::string type_string = get_type_string(q.dtype());
-  std::string kname;
-  kname.reserve(64);
-  concatenate(
-      kname,
-      "paged_attention_",
-      type_string,
-      "_hs",
-      head_size,
-      "_bs",
-      block_size,
-      "_nt",
-      num_threads,
-      "_nsl",
-      num_simd_lanes,
-      "_ps",
-      partition_size);
-
-  auto template_def = get_template_definition(
-      kname,
-      "paged_attention",
-      type_string,
+  run_paged_attention(
+      q,
+      k_cache,
+      v_cache,
+      block_tables,
+      context_lens,
       head_size,
       block_size,
-      num_threads,
-      num_simd_lanes,
-      partition_size);
-
-  // Encode and dispatch kernel
-  metal::MTLFCList func_consts = {
-      {use_partitioning, MTL::DataType::DataTypeBool, 10},
-      {use_alibi, MTL::DataType::DataTypeBool, 20},
-  };
-
-  std::string hash_name = kname;
-  auto kernel = get_paged_attention_kernel(
-      d, kname, hash_name, func_consts, template_def);
-  auto& compute_encoder = d.get_command_encoder(s.index);
-  compute_encoder.set_compute_pipeline_state(kernel);
-
-  auto tmp_out =
-      array({num_seqs, num_heads, max_num_partitions, head_size}, float32);
-  tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
-  auto exp_sums = array({num_seqs, num_heads, max_num_partitions}, float32);
-  exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
-
-  std::vector<array> copies = {tmp_out, exp_sums};
-
-  auto num_simds = num_threads / num_simd_lanes;
-  auto max_num_partitions =
-      (max_context_len + partition_size - 1) / partition_size;
-  auto logits_size = partition_size * size_of(float32);
-  auto outputs_size = (num_simds / 2) * head_size * size_of(float32);
-  auto shared_mem_size = std::max(logits_size, outputs_size);
-  compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
-
-  compute_encoder.set_output_array(tmp_out, 0);
-  compute_encoder.set_output_array(exp_sums, 1);
-  compute_encoder.set_output_array(out, 2);
-  compute_encoder.set_input_array(q, 3);
-  compute_encoder.set_input_array(k_cache, 4);
-  compute_encoder.set_input_array(v_cache, 5);
-
-  compute_encoder.set_bytes(num_kv_heads, 6);
-  compute_encoder.set_bytes(scale, 7);
-  compute_encoder.set_bytes(softcapping, 8);
-
-  compute_encoder.set_input_array(block_tables, 9);
-  compute_encoder.set_input_array(context_lens, 10);
-
-  compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
-
-  if (use_alibi) {
-    compute_encoder.set_input_array(alibi.value(), 12);
-  }
-
-  compute_encoder.set_bytes(q_stride, 13);
-  compute_encoder.set_bytes(kv_block_stride, 14);
-  compute_encoder.set_bytes(kv_head_stride, 15);
-
-  MTL::Size grid_dims(num_heads, num_seqs, max_num_partitions);
-  MTL::Size group_dims(num_threads, 1, 1);
-
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
-
-  d.add_temporaries(std::move(copies), s.index);
-  return;
+      num_kv_heads,
+      scale,
+      softcapping,
+      max_context_len,
+      max_num_blocks_per_seq,
+      /*use_partitioning=*/true,
+      alibi,
+      q_stride,
+      kv_block_stride,
+      kv_head_stride,
+      num_heads,
+      num_seqs,
+      out,
+      d,
+      s);
 }
 
 void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -258,7 +283,7 @@ void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
       block_size_,
       num_kv_heads_,
       softmax_scale_,
-      softcapping_.value_or(0.),
+      softcapping_.value_or(1.),
      max_context_len_,
       max_num_blocks_per_seq_,
       alibi_slopes,
@@ -281,7 +306,7 @@
       block_size_,
       num_kv_heads_,
       softmax_scale_,
-      softcapping_.value_or(0.),
+      softcapping_.value_or(1.),
       max_context_len_,
       max_num_blocks_per_seq_,
       max_num_partitions_,
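
Note (not part of the patch): the following is a small standalone sketch of the partition and
threadgroup-memory sizing that run_paged_attention performs on the use_partitioning (v2) path.
The kernel constants (num_threads = 256, num_simd_lanes = 32, partition_size = 512) are taken
from the patch; head_size and max_context_len are illustrative values only, and sizeof(float)
stands in for MLX's size_of(float32), which is 4 bytes.

#include <algorithm>
#include <cstdio>

int main() {
  const int num_threads = 256;     // threads per threadgroup (from the patch)
  const int num_simd_lanes = 32;   // SIMD width (from the patch)
  const int partition_size = 512;  // tokens per context partition on the v2 path

  const int head_size = 128;         // assumed head dimension
  const int max_context_len = 2048;  // assumed longest context in the batch

  // Number of context partitions per (head, sequence) pair; this is the z
  // dimension of the dispatch grid and the extra axis of the tmp_out /
  // exp_sums temporaries.
  int max_num_partitions =
      (max_context_len + partition_size - 1) / partition_size;

  // Threadgroup memory: enough for either the per-partition logits or the
  // per-simdgroup partial outputs, whichever is larger.
  int logits_size = partition_size * sizeof(float);
  int outputs_size =
      ((num_threads / num_simd_lanes) / 2) * head_size * sizeof(float);
  int shared_mem_size = std::max(logits_size, outputs_size);

  std::printf(
      "partitions = %d, shared memory = %d bytes\n",
      max_num_partitions,
      shared_mem_size);
  // With these values: partitions = 4, shared memory = 2048 bytes.
  return 0;
}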