This commit is contained in:
Eric Buehler 2025-06-16 22:37:33 +02:00 committed by GitHub
commit 14d22ddedb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1957 additions and 0 deletions

View File

@ -12,6 +12,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp

View File

@ -102,6 +102,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/paged_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp

View File

@ -241,6 +241,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
int wn,
bool transpose);
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string

View File

@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(paged_attention paged_attention.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)
build_kernel(sort sort.h)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,131 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/paged_attention.h"
#include "mlx/backend/metal/kernels/utils.h"
#define instantiate_paged_attention_inner( \
type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \
template \
[[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention< \
type, \
head_size, \
block_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device float* exp_sums \
[[buffer(0), function_constant(use_partitioning)]], \
device float* max_logits \
[[buffer(1), function_constant(use_partitioning)]], \
device type* out [[buffer(2)]], \
device const type* q [[buffer(3)]], \
device const type* k_cache [[buffer(4)]], \
device const type* v_cache [[buffer(5)]], \
const constant int& num_kv_heads [[buffer(6)]], \
const constant float& scale [[buffer(7)]], \
const constant float& softcapping [[buffer(8)]], \
device const uint32_t* block_tables [[buffer(9)]], \
device const uint32_t* context_lens [[buffer(10)]], \
const constant int& max_num_blocks_per_seq [[buffer(11)]], \
device const float* alibi_slopes \
[[buffer(12), function_constant(use_alibi)]], \
const constant int& q_stride [[buffer(13)]], \
const constant int& kv_block_stride [[buffer(14)]], \
const constant int& kv_head_stride [[buffer(15)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup \
[[thread_position_in_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_v2_reduce_inner( \
type, head_size, num_threads, num_simd_lanes, partition_size) \
template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
"_nt" #num_threads "_nsl" #num_simd_lanes \
"_ps" #partition_size)]] [[kernel]] void \
paged_attention_v2_reduce< \
type, \
head_size, \
num_threads, \
num_simd_lanes, \
partition_size>( \
device type * out [[buffer(0)]], \
const device float* exp_sums [[buffer(1)]], \
const device float* max_logits [[buffer(2)]], \
const device type* tmp_out [[buffer(3)]], \
device uint32_t* context_lens [[buffer(4)]], \
const constant int& max_num_partitions [[buffer(5)]], \
threadgroup char* shared_mem [[threadgroup(0)]], \
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
uint simd_tid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_paged_attention_heads( \
type, block_size, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_inner( \
type, 64, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 80, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 96, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 112, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 128, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 192, block_size, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_inner( \
type, 256, block_size, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_v2_reduce_heads( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_v2_reduce_inner( \
type, 64, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 80, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 96, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 112, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 128, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 192, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_v2_reduce_inner( \
type, 256, num_threads, num_simd_lanes, partition_size);
#define instantiate_paged_attention_block_size( \
type, num_threads, num_simd_lanes, partition_size) \
instantiate_paged_attention_heads( \
type, 8, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 16, num_threads, num_simd_lanes, partition_size); \
instantiate_paged_attention_heads( \
type, 32, num_threads, num_simd_lanes, partition_size);
// TODO: tune num_threads = 256
// NOTE: partition_size = 0
#define instantiate_paged_attention_v1(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0);
// TODO: tune num_threads = 256
// NOTE: partition_size = 512
#define instantiate_paged_attention_v2(type, num_simd_lanes) \
instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512); \
instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
instantiate_paged_attention_v1(float, 32);
instantiate_paged_attention_v1(bfloat16_t, 32);
instantiate_paged_attention_v1(half, 32);
instantiate_paged_attention_v2(float, 32);
instantiate_paged_attention_v2(bfloat16_t, 32);
instantiate_paged_attention_v2(half, 32);

View File

@ -288,4 +288,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
return d.get_kernel(kernel_name, hash_name, func_consts);
}
MTL::ComputePipelineState* get_paged_attention_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string&) {
return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
}
} // namespace mlx::core

View File

@ -0,0 +1,324 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/paged_attention_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
static void run_paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const bool use_partitioning,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
const int partition_size = use_partitioning ? 512 : 0;
const int num_threads = 256;
const int num_simd_lanes = 32;
const bool use_alibi = alibi.has_value();
std::string type_string = get_type_string(q.dtype());
std::string kname;
kname.reserve(64);
concatenate(
kname,
"paged_attention_",
type_string,
"_hs",
head_size,
"_bs",
block_size,
"_nt",
num_threads,
"_nsl",
num_simd_lanes,
"_ps",
partition_size);
auto template_def = get_template_definition(
kname,
"paged_attention",
type_string,
head_size,
block_size,
num_threads,
num_simd_lanes,
partition_size);
// Encode and dispatch kernel
metal::MTLFCList func_consts = {
{use_partitioning, MTL::DataType::DataTypeBool, 10},
{use_alibi, MTL::DataType::DataTypeBool, 20},
};
std::string hash_name = kname;
auto kernel = get_paged_attention_kernel(
d, kname, hash_name, func_consts, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
int local_max_num_partitions = 1;
if (use_partitioning) {
local_max_num_partitions =
(max_context_len + partition_size - 1) / partition_size;
}
int logits_size = use_partitioning ? partition_size * size_of(float32) : 0;
int outputs_size = use_partitioning
? ((num_threads / num_simd_lanes) / 2) * head_size * size_of(float32)
: 0;
int shared_mem_size =
use_partitioning ? std::max(logits_size, outputs_size) : 0;
if (use_partitioning) {
compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
}
if (use_partitioning) {
auto tmp_out = array(
{num_seqs, num_heads, local_max_num_partitions, head_size}, float32);
tmp_out.set_data(allocator::malloc(tmp_out.nbytes()));
auto exp_sums =
array({num_seqs, num_heads, local_max_num_partitions}, float32);
exp_sums.set_data(allocator::malloc(exp_sums.nbytes()));
std::vector<array> temporaries = {tmp_out, exp_sums};
compute_encoder.set_output_array(tmp_out, 0);
compute_encoder.set_output_array(exp_sums, 1);
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, local_max_num_partitions);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(temporaries), s.index);
} else {
compute_encoder.set_output_array(out, 2);
compute_encoder.set_input_array(q, 3);
compute_encoder.set_input_array(k_cache, 4);
compute_encoder.set_input_array(v_cache, 5);
compute_encoder.set_bytes(num_kv_heads, 6);
compute_encoder.set_bytes(scale, 7);
compute_encoder.set_bytes(softcapping, 8);
compute_encoder.set_input_array(block_tables, 9);
compute_encoder.set_input_array(context_lens, 10);
compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
if (use_alibi) {
compute_encoder.set_input_array(alibi.value(), 12);
}
compute_encoder.set_bytes(q_stride, 13);
compute_encoder.set_bytes(kv_block_stride, 14);
compute_encoder.set_bytes(kv_head_stride, 15);
MTL::Size grid_dims(num_heads, num_seqs, 1);
MTL::Size group_dims(num_threads, 1, 1);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
}
void paged_attention_v1(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/false,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void paged_attention_v2(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
const int head_size,
const int block_size,
const int num_kv_heads,
const float scale,
const float softcapping,
const int max_context_len,
const int max_num_blocks_per_seq,
const int /* max_num_partitions */,
const std::optional<array> alibi,
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const int num_heads,
const int num_seqs,
array& out,
metal::Device& d,
const Stream& s) {
run_paged_attention(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size,
block_size,
num_kv_heads,
scale,
softcapping,
max_context_len,
max_num_blocks_per_seq,
/*use_partitioning=*/true,
alibi,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
out,
d,
s);
}
void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
out.set_data(allocator::malloc(out.nbytes()));
auto& q = inputs[0];
auto& k_cache = inputs[1];
auto& v_cache = inputs[2];
auto& block_tables = inputs[3];
auto& context_lens = inputs[4];
const auto alibi_slopes =
inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
if (use_v1_) {
paged_attention_v1(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
} else {
paged_attention_v2(
q,
k_cache,
v_cache,
block_tables,
context_lens,
head_size_,
block_size_,
num_kv_heads_,
softmax_scale_,
softcapping_.value_or(1.),
max_context_len_,
max_num_blocks_per_seq_,
max_num_partitions_,
alibi_slopes,
q_stride_,
kv_block_stride_,
kv_head_stride_,
num_heads_,
num_seqs_,
out,
d,
s);
}
}
} // namespace mlx::core::paged_attention

View File

@ -17,6 +17,7 @@
#include "mlx/linalg.h"
#include "mlx/memory.h"
#include "mlx/ops.h"
#include "mlx/paged_attention.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"

170
mlx/paged_attention.cpp Normal file
View File

@ -0,0 +1,170 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
#include <sstream>
#include <stdexcept>
#include "mlx/paged_attention_primitives.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {}) {
auto s = to_stream(s_);
// supported dtypes
if (!issubdtype(q.dtype(), floating)) {
throw std::invalid_argument(
"[paged_attention] Only real floating types are supported");
}
if (!(q.dtype() == k_cache.dtype() && k_cache.dtype() == v_cache.dtype())) {
throw std::invalid_argument(
"[paged_attention] q/k_cache/v_cache dtype must match");
}
if (!(block_tables.dtype() == uint32 && context_lens.dtype() == uint32)) {
throw std::invalid_argument(
"[paged_attention] block_tables/context_lens dtype must be uint32");
}
// rank checks
if (q.ndim() != 3)
throw std::invalid_argument("[paged_attention] `q` must be rank-3");
if (k_cache.ndim() != 5)
throw std::invalid_argument("[paged_attention] `k_cache` must be rank-5");
if (v_cache.ndim() != 4)
throw std::invalid_argument("[paged_attention] `v_cache` must be rank-4");
if (block_tables.ndim() != 2)
throw std::invalid_argument(
"[paged_attention] `block_tables` must be rank-2");
if (context_lens.ndim() != 1)
throw std::invalid_argument(
"[paged_attention] `context_lens` must be rank-1");
// 4. Shape consistency
const auto& q_shape = q.shape(); // [num_seqs, num_heads, head_size]
const auto& kc_shape = k_cache.shape();
const auto& vc_shape = v_cache.shape();
const auto& bt_shape = block_tables.shape();
const auto& cl_shape = context_lens.shape();
int num_seqs = q_shape[0];
int num_heads = q_shape[1];
int head_size = q_shape[2];
// Allowed head sizes
switch (head_size) {
case 64:
case 80:
case 96:
case 112:
case 128:
case 192:
case 256:
break;
default:
throw std::invalid_argument(
"[paged_attention] `head_size` must be one of "
"{64, 80, 96, 112, 128, 192, 256}");
}
int max_num_blocks_per_seq = bt_shape[1];
// block_tables first dimension must match num_seqs
if (bt_shape[0] != num_seqs) {
std::stringstream ss;
ss << "[paged_attention] block_tables.shape[0] (" << bt_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
// Extract k_cache dimensions
int num_blocks = kc_shape[0];
int num_kv_heads = kc_shape[1];
int head_size_kc = kc_shape[2];
int block_size = kc_shape[3];
int x = kc_shape[4];
if (head_size_kc * x != head_size) {
std::stringstream ss;
ss << "[paged_attention] k_cache head_size (" << head_size_kc << " * " << x
<< ") must equal q head_size (" << head_size << ")";
throw std::invalid_argument(ss.str());
}
// v_cache must match the derived dimensions
if (!(vc_shape[0] == num_blocks && vc_shape[1] == num_kv_heads &&
vc_shape[2] == head_size && vc_shape[3] == block_size)) {
throw std::invalid_argument(
"[paged_attention] `v_cache` shape mismatch with `k_cache`/`q`");
}
// context_lens length must match num_seqs
if (cl_shape[0] != num_seqs) {
std::stringstream ss;
ss << "paged_attention: context_lens length (" << cl_shape[0]
<< ") must equal q.shape[0] (" << num_seqs << ")";
throw std::invalid_argument(ss.str());
}
constexpr int partition_size = 512;
int max_num_partitions =
(max_context_len + partition_size - 1) / partition_size; // ceildiv
bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
(partition_size % block_size == 0);
auto out_shape = q.shape();
auto inputs = std::vector{
std::move(q),
std::move(k_cache),
std::move(v_cache),
std::move(block_tables),
std::move(context_lens)};
if (alibi_slopes.has_value()) {
inputs.push_back(std::move(alibi_slopes.value()));
}
int q_stride = q.strides()[0];
int kv_block_stride = k_cache.strides()[0];
int kv_head_stride = k_cache.strides()[1];
return array(
std::move(out_shape),
q.dtype(),
std::make_shared<PagedAttention>(
to_stream(s),
use_v1,
max_context_len,
head_size,
block_size,
num_kv_heads,
softmax_scale,
max_num_blocks_per_seq,
max_num_partitions,
q_stride,
kv_block_stride,
kv_head_stride,
num_heads,
num_seqs,
softcapping),
inputs);
}
} // namespace mlx::core::paged_attention

34
mlx/paged_attention.h Normal file
View File

@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"
namespace mlx::core::paged_attention {
/**
* \defgroup ops Paged attention operations
* @{
*/
/** PagedAttention operation. */
array paged_attention(
const array& q,
const array& k_cache,
const array& v_cache,
const array& block_tables,
const array& context_lens,
int max_context_len,
float softmax_scale,
std::optional<array> alibi_slopes = std::nullopt,
std::optional<float> softcapping = std::nullopt,
StreamOrDevice s_ = {});
/** @} */
} // namespace mlx::core::paged_attention

View File

@ -0,0 +1,82 @@
// Copyright © 2023-2024 Apple Inc.
// Required for using M_PI in MSVC.
#define _USE_MATH_DEFINES
#include <optional>
#include "mlx/primitives.h"
namespace mlx::core::paged_attention {
class PagedAttention : public UnaryPrimitive {
public:
explicit PagedAttention(
Stream stream,
bool use_v1,
int max_context_len,
int head_size,
int block_size,
int num_kv_heads,
int max_num_blocks_per_seq,
int max_num_partitions,
int q_stride,
int kv_block_stride,
int kv_head_stride,
int num_heads,
int num_seqs,
float softmax_scale,
std::optional<float> softcapping = std::nullopt)
: UnaryPrimitive(stream),
use_v1_(use_v1),
max_context_len_(max_context_len),
head_size_(head_size),
block_size_(block_size),
num_kv_heads_(num_kv_heads),
max_num_blocks_per_seq_(max_num_blocks_per_seq),
max_num_partitions_(max_num_partitions),
q_stride_(q_stride),
kv_block_stride_(kv_block_stride),
kv_head_stride_(kv_head_stride),
num_heads_(num_heads),
num_seqs_(num_seqs),
softmax_scale_(softmax_scale),
softcapping_(softcapping) {}
void eval_cpu(const std::vector<array>& inputs, array& outputs) override {
throw std::runtime_error("NYI");
}
void eval_gpu(const std::vector<array>& inputs, array& outputs) override;
DEFINE_PRINT(PagedAttention);
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(
max_context_len_,
head_size_,
block_size_,
softmax_scale_,
softcapping_);
}
private:
bool use_v1_;
int max_context_len_;
int head_size_;
int block_size_;
int num_kv_heads_;
int max_num_blocks_per_seq_;
int max_num_partitions_;
int q_stride_;
int kv_block_stride_;
int kv_head_stride_;
int num_heads_;
int num_seqs_;
float softmax_scale_;
std::optional<float> softcapping_ = std::nullopt;
};
} // namespace mlx::core::paged_attention