Add v1 call

Eric Buehler 2025-05-31 09:16:14 -04:00
parent 221edc4a65
commit 0b5c5680f4
6 changed files with 225 additions and 16 deletions

View File

@ -239,6 +239,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int wn,
bool transpose);
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@ -2,6 +2,7 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
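
The host-side dispatch added in this commit binds two boolean function constants at indices 10 (use_partitioning) and 20 (use_alibi). A minimal sketch of the kernel-side declarations that would match those bindings; the names are illustrative, only the indices and the bool type come from the host code:

constant bool use_partitioning [[function_constant(10)]];
constant bool use_alibi [[function_constant(20)]];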

View File

@ -286,4 +286,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
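// The trailing template-definition argument is unused here; the specialized
// kernel is looked up in the "mlx" library under hash_name with the given
// function constants.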
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@ -1,11 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/paged_attention_primitives.h"
#include "mlx/primitives.h"
@ -13,6 +10,125 @@
namespace mlx::core::paged_attention {
void paged_attention_v1(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
const int partition_size = 0;
const int num_threads = 256;
const int num_simd_lanes = 32;
const bool use_partitioning = false;
const bool use_alibi = alibi.has_value();
std::string type_string = get_type_string(q.dtype());
std::string kname;
kname.reserve(64);
concatenate(
kname,
"paged_attention_",
type_string,
"_hs",
head_size,
"_bs",
block_size,
"_nt",
num_threads,
"_nsl",
num_simd_lanes,
"_ps",
partition_size);
auto template_def = get_template_definition(
kname,
"paged_attention",
type_string,
head_size,
block_size,
num_threads,
num_simd_lanes,
partition_size);
// Encode and dispatch kernel
metal::MTLFCList func_consts = {
{use_partitioning, MTL::DataType::DataTypeBool, 10},
{use_alibi, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
auto kernel = get_paged_attention_kernel(
d, kname, hash_name, func_consts, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
auto num_simds = num_threads / num_simd_lanes;
// v1 is unpartitioned (partition_size == 0), so the logits scratch is empty
// and the threadgroup allocation is sized by the cross-simdgroup output
// reduction.
auto logits_size = partition_size * size_of(float32);
auto outputs_size = (num_simds / 2) * head_size * size_of(float32);
auto shared_mem_size = std::max(logits_size, outputs_size);
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
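// Buffer indices start at 2; slots 0 and 1 are presumably reserved for the
// partitioned (v2) partial results that will share this kernel signature.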
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, 1);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
void paged_attention_v2(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
array& out,
metal::Device& d,
const Stream& s) {
throw std::runtime_error("paged_attention_v2 is not yet implemented.");
}
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
@ -26,6 +142,31 @@ void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& context_lens = inputs[4];
const auto alibi_slopes =
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
if (use_v1_) {
paged_attention_v1(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(0.),
max_context_len_,
max_num_blocks_per_seq_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
} else {
paged_attention_v2(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
out,
d,
s);
}
}
} // namespace mlx::core::paged_attention
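
For reference, the kernel name assembled above follows the pattern paged_attention_<type>_hs<head size>_bs<block size>_nt<threads>_nsl<simd lanes>_ps<partition size>. A standalone sketch that reproduces the concatenation for one illustrative configuration, assuming get_type_string maps float16 to "float16_t" as it does for other Metal kernels:

#include <iostream>
#include <sstream>

int main() {
  // Mirrors the concatenate(...) call in paged_attention_v1.
  std::ostringstream kname;
  kname << "paged_attention_" << "float16_t" << "_hs" << 128 << "_bs" << 16
        << "_nt" << 256 << "_nsl" << 32 << "_ps" << 0;
  std::cout << kname.str() << "\n";
  // Prints: paged_attention_float16_t_hs128_bs16_nt256_nsl32_ps0
}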

View File

@ -64,9 +64,9 @@ array paged_attention(
const auto& bt_shape = block_tables.shape();
const auto& cl_shape = context_lens.shape();
int64_t num_seqs = q_shape[0];
int64_t num_heads = q_shape[1];
int64_t head_size = q_shape[2];
int num_seqs = q_shape[0];
int num_heads = q_shape[1];
int head_size = q_shape[2];
// Allowed head sizes
switch (head_size) {
@ -84,6 +84,8 @@ array paged_attention(
"{64, 80, 96, 112, 128, 192, 256}");
}
int max_num_blocks_per_seq = bt_shape[1];
// block_tables first dimension must match num_seqs
if (bt_shape[0] != num_seqs) {
std::stringstream ss;
@ -93,11 +95,11 @@ array paged_attention(
}
// Extract k_cache dimensions
int64_t num_blocks = kc_shape[0];
int64_t num_kv_heads = kc_shape[1];
int64_t head_size_kc = kc_shape[2];
int64_t block_size = kc_shape[3];
int64_t x = kc_shape[4];
int num_blocks = kc_shape[0];
int num_kv_heads = kc_shape[1];
int head_size_kc = kc_shape[2];
int block_size = kc_shape[3];
int x = kc_shape[4];
if (head_size_kc * x != head_size) {
std::stringstream ss;
@ -121,8 +123,8 @@ array paged_attention(
throw std::invalid_argument(ss.str());
}
constexpr int64_t partition_size = 512;
int64_t max_num_partitions =
constexpr int partition_size = 512;
int max_num_partitions =
(max_context_len + partition_size - 1) / partition_size; // ceildiv
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
(partition_size % block_size == 0);
@ -139,11 +141,28 @@ array paged_attention(
inputs.push_back(std::move(alibi_slopes.value()));
}
int q_stride = q.strides()[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
return array(
std::move(out_shape),
q.dtype(),
std::make_shared<PagedAttention>(
to_stream(s), use_v1, max_context_len, softmax_scale, softcapping),
to_stream(s),
use_v1,
max_context_len,
head_size,
block_size,
num_kv_heads,
max_num_blocks_per_seq,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
softmax_scale,
softcapping),
inputs);
}
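
The v1/v2 choice above hinges on the partition count and the batch size. A quick standalone check of that heuristic, restated as a lambda and evaluated for two made-up shapes:

#include <cstdio>

int main() {
  constexpr int partition_size = 512;
  auto use_v1 = [&](int max_context_len, int num_seqs, int num_heads, int block_size) {
    int max_num_partitions =
        (max_context_len + partition_size - 1) / partition_size; // ceildiv
    return ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
        (partition_size % block_size == 0);
  };
  // Short context: a single partition, so v1 is used regardless of batch size.
  std::printf("%d\n", use_v1(300, 1, 32, 16)); // 1
  // Long context with a small batch: 4 partitions and 8 * 32 <= 512, so v2.
  std::printf("%d\n", use_v1(2048, 8, 32, 16)); // 0
}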

View File

@ -15,11 +15,29 @@ class PagedAttention : public UnaryPrimitive {
Stream stream,
bool use_v1,
int max_context_len,
int head_size,
int block_size,
int num_kv_heads,
int max_num_blocks_per_seq,
int q_stride,
int kv_block_stride,
int kv_head_stride,
int num_heads,
int num_seqs,
float softmax_scale,
std::optional<float> softcapping = std::nullopt)
: UnaryPrimitive(stream),
use_v1_(use_v1),
max_context_len_(max_context_len),
head_size_(head_size),
block_size_(block_size),
num_kv_heads_(num_kv_heads),
max_num_blocks_per_seq_(max_num_blocks_per_seq),
q_stride_(q_stride),
kv_block_stride_(kv_block_stride),
kv_head_stride_(kv_head_stride),
num_heads_(num_heads),
num_seqs_(num_seqs),
softmax_scale_(softmax_scale),
softcapping_(softcapping) {}
@ -34,12 +52,26 @@ class PagedAttention : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(max_context_len_, softmax_scale_, softcapping_);
return std::make_tuple(
max_context_len_,
head_size_,
block_size_,
softmax_scale_,
softcapping_);
}
private:
bool use_v1_;
int max_context_len_;
int head_size_;
int block_size_;
int num_kv_heads_;
int max_num_blocks_per_seq_;
int q_stride_;
int kv_block_stride_;
int kv_head_stride_;
int num_heads_;
int num_seqs_;
float softmax_scale_;
std::optional<float> softcapping_ = std::nullopt;
};
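
As a reading aid, the shape relationships that the paged_attention() factory validates can be summarized in one small sketch; the dimension names and numbers are illustrative, only the relationships between them come from the checks in the code above:

#include <cassert>
#include <vector>

int main() {
  int num_seqs = 4, num_heads = 32, head_size = 128;
  int num_blocks = 1024, num_kv_heads = 8, block_size = 16, x = 8;
  int max_num_blocks_per_seq = 64;

  std::vector<int> q_shape = {num_seqs, num_heads, head_size};
  std::vector<int> kc_shape = {num_blocks, num_kv_heads, head_size / x, block_size, x};
  std::vector<int> bt_shape = {num_seqs, max_num_blocks_per_seq};
  std::vector<int> cl_shape = {num_seqs}; // assumed: one context length per sequence

  assert(kc_shape[2] * kc_shape[4] == head_size); // head_size_kc * x == head_size
  assert(bt_shape[0] == num_seqs);                // block_tables rows match num_seqs
  (void)q_shape;
  (void)cl_shape;
  return 0;
}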