mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

Add v1 call

This commit is contained in:
    parent 221edc4a65
    commit 0b5c5680f4
@@ -239,6 +239,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
     int wn,
     bool transpose);
 
+MTL::ComputePipelineState* get_paged_attention_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const std::string&);
+
 // Create a GPU kernel template definition for JIT compilation
 template <typename... Args>
 std::string
@@ -2,6 +2,7 @@
 
+#include <metal_simdgroup>
 #include <metal_stdlib>
 #include "mlx/backend/metal/kernels/utils.h"
 
 using namespace metal;
 
@@ -286,4 +286,13 @@ MTL::ComputePipelineState* get_gather_qmm_kernel(
   return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
 }
 
+MTL::ComputePipelineState* get_paged_attention_kernel(
+    metal::Device& d,
+    const std::string& kernel_name,
+    const std::string& hash_name,
+    const metal::MTLFCList& func_consts,
+    const std::string&) {
+  return d.get_kernel(kernel_name, "mlx", hash_name, func_consts);
+}
+
 } // namespace mlx::core
@@ -1,11 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 
-#include "mlx/backend/common/broadcasting.h"
 #include "mlx/backend/common/compiled.h"
-#include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
-#include "mlx/backend/metal/reduce.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/paged_attention_primitives.h"
 #include "mlx/primitives.h"
@@ -13,6 +10,125 @@
 namespace mlx::core::paged_attention {
 
+void paged_attention_v1(
+    const array& q,
+    const array& k_cache,
+    const array& v_cache,
+    const array& block_tables,
+    const array& context_lens,
+    const int head_size,
+    const int block_size,
+    const int num_kv_heads,
+    const float scale,
+    const float softcapping,
+    const int max_context_len,
+    const int max_num_blocks_per_seq,
+    const std::optional<array> alibi,
+    const int q_stride,
+    const int kv_block_stride,
+    const int kv_head_stride,
+    const int num_heads,
+    const int num_seqs,
+    array& out,
+    metal::Device& d,
+    const Stream& s) {
+  const int partition_size = 0;
+  const int num_threads = 256;
+  const int num_simd_lanes = 32;
+  const bool use_partitioning = false;
+  const bool use_alibi = alibi.has_value();
+
+  std::string type_string = get_type_string(q.dtype());
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "paged_attention_",
+      type_string,
+      "_hs",
+      head_size,
+      "_bs",
+      block_size,
+      "_nt",
+      num_threads,
+      "_nsl",
+      num_simd_lanes,
+      "_ps",
+      partition_size);
+
+  auto template_def = get_template_definition(
+      kname,
+      "paged_attention",
+      type_string,
+      head_size,
+      block_size,
+      num_threads,
+      num_simd_lanes,
+      partition_size);
+
+  // Encode and dispatch kernel
+  metal::MTLFCList func_consts = {
+      {use_partitioning, MTL::DataType::DataTypeBool, 10},
+      {use_alibi, MTL::DataType::DataTypeBool, 20},
+  };
+
+  std::string hash_name = kname;
+  auto kernel = get_paged_attention_kernel(
+      d, kname, hash_name, func_consts, template_def);
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto num_simds = num_threads / num_simd_lanes;
+  auto max_num_partitions =
+      (max_context_len + partition_size - 1) / partition_size;
+  auto logits_size = partition_size * size_of(float32);
+  auto outputs_size = (num_simds / 2) * head_size * size_of(float32);
+  auto shared_mem_size = std::max(logits_size, outputs_size);
+  compute_encoder.set_threadgroup_memory_length(shared_mem_size, 0);
+
+  compute_encoder.set_output_array(out, 2);
+  compute_encoder.set_input_array(q, 3);
+  compute_encoder.set_input_array(k_cache, 4);
+  compute_encoder.set_input_array(v_cache, 5);
+
+  compute_encoder.set_bytes(num_kv_heads, 6);
+  compute_encoder.set_bytes(scale, 7);
+  compute_encoder.set_bytes(softcapping, 8);
+
+  compute_encoder.set_input_array(block_tables, 9);
+  compute_encoder.set_input_array(context_lens, 10);
+
+  compute_encoder.set_bytes(max_num_blocks_per_seq, 11);
+
+  if (use_alibi) {
+    compute_encoder.set_input_array(alibi.value(), 12);
+  }
+
+  compute_encoder.set_bytes(q_stride, 13);
+  compute_encoder.set_bytes(kv_block_stride, 14);
+  compute_encoder.set_bytes(kv_head_stride, 15);
+
+  MTL::Size grid_dims(num_heads, num_seqs, 1);
+  MTL::Size group_dims(num_threads, 1, 1);
+
+  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+  return;
+}
+
+void paged_attention_v2(
+    const array& q,
+    const array& k_cache,
+    const array& v_cache,
+    const array& block_tables,
+    const array& context_lens,
+    const int head_size,
+    const int block_size,
+    array& out,
+    metal::Device& d,
+    const Stream& s) {
+  throw std::runtime_error("NYI");
+}
+
 void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
   auto& d = metal::device(s.device);
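For reference, a minimal standalone sketch of the JIT kernel name the hunk above assembles, for a hypothetical float16 query with head_size = 128 and block_size = 16. The "float16_t" type string is an assumption standing in for get_type_string(); the suffix pieces mirror the concatenate() call in the diff, nothing here is part of the change itself.

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Assumed type string for a float16 query; the remaining pieces follow the
  // concatenate() call in the hunk above.
  std::string type_string = "float16_t";
  int head_size = 128, block_size = 16;
  int num_threads = 256, num_simd_lanes = 32, partition_size = 0;

  std::ostringstream kname;
  kname << "paged_attention_" << type_string << "_hs" << head_size << "_bs"
        << block_size << "_nt" << num_threads << "_nsl" << num_simd_lanes
        << "_ps" << partition_size;

  // Prints: paged_attention_float16_t_hs128_bs16_nt256_nsl32_ps0
  std::cout << kname.str() << "\n";
  return 0;
}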
@@ -26,6 +142,31 @@ void PagedAttention::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& context_lens = inputs[4];
   const auto alibi_slopes =
       inputs.size() == 6 ? std::optional{inputs[5]} : std::nullopt;
-  return;
+
+  if (use_v1_) {
+    paged_attention_v1(
+        q,
+        k_cache,
+        v_cache,
+        block_tables,
+        context_lens,
+        head_size_,
+        block_size_,
+        num_kv_heads_,
+        softmax_scale_,
+        softcapping_.value_or(0.),
+        max_context_len_,
+        max_num_blocks_per_seq_,
+        alibi_slopes,
+        q_stride_,
+        kv_block_stride_,
+        kv_head_stride_,
+        num_heads_,
+        num_seqs_,
+        out,
+        d,
+        s);
+  } else {
+  }
 }
 } // namespace mlx::core::paged_attention
@@ -64,9 +64,9 @@ array paged_attention(
   const auto& bt_shape = block_tables.shape();
   const auto& cl_shape = context_lens.shape();
 
-  int64_t num_seqs = q_shape[0];
-  int64_t num_heads = q_shape[1];
-  int64_t head_size = q_shape[2];
+  int num_seqs = q_shape[0];
+  int num_heads = q_shape[1];
+  int head_size = q_shape[2];
 
   // Allowed head sizes
   switch (head_size) {
@@ -84,6 +84,8 @@ array paged_attention(
         "{64, 80, 96, 112, 128, 192, 256}");
   }
 
+  int max_num_blocks_per_seq = bt_shape[1];
+
   // block_tables first dimension must match num_seqs
   if (bt_shape[0] != num_seqs) {
     std::stringstream ss;
@@ -93,11 +95,11 @@ array paged_attention(
   }
 
   // Extract k_cache dimensions
-  int64_t num_blocks = kc_shape[0];
-  int64_t num_kv_heads = kc_shape[1];
-  int64_t head_size_kc = kc_shape[2];
-  int64_t block_size = kc_shape[3];
-  int64_t x = kc_shape[4];
+  int num_blocks = kc_shape[0];
+  int num_kv_heads = kc_shape[1];
+  int head_size_kc = kc_shape[2];
+  int block_size = kc_shape[3];
+  int x = kc_shape[4];
 
   if (head_size_kc * x != head_size) {
     std::stringstream ss;
@@ -121,8 +123,8 @@ array paged_attention(
     throw std::invalid_argument(ss.str());
   }
 
-  constexpr int64_t partition_size = 512;
-  int64_t max_num_partitions =
+  constexpr int partition_size = 512;
+  int max_num_partitions =
       (max_context_len + partition_size - 1) / partition_size; // ceil-div
   bool use_v1 = ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
       (partition_size % block_size == 0);
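A self-contained sketch of the v1/v2 selection heuristic in the hunk above, handy for checking which path a given configuration takes. The helper name choose_v1 and the example shapes are illustrative only and not part of this change.

#include <iostream>

// Mirrors the use_v1 computation from the hunk above.
static bool choose_v1(
    int max_context_len, int num_seqs, int num_heads, int block_size) {
  constexpr int partition_size = 512;
  int max_num_partitions =
      (max_context_len + partition_size - 1) / partition_size; // ceil-div
  return ((max_num_partitions == 1) || (num_seqs * num_heads > 512)) &&
      (partition_size % block_size == 0);
}

int main() {
  // Short context: a single partition suffices, so the v1 kernel is chosen.
  std::cout << choose_v1(512, /*num_seqs=*/1, /*num_heads=*/32, /*block_size=*/16)
            << "\n"; // 1
  // Long context with little head/sequence parallelism: v2 (partitioned) path.
  std::cout << choose_v1(4096, 1, 8, 16) << "\n"; // 0
  return 0;
}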
@@ -139,11 +141,28 @@ array paged_attention(
     inputs.push_back(std::move(alibi_slopes.value()));
   }
 
+  int q_stride = q.strides()[0];
+  int kv_block_stride = k_cache.strides()[0];
+  int kv_head_stride = k_cache.strides()[1];
+
   return array(
       std::move(out_shape),
       q.dtype(),
       std::make_shared<PagedAttention>(
-          to_stream(s), use_v1, max_context_len, softmax_scale, softcapping),
+          to_stream(s),
+          use_v1,
+          max_context_len,
+          head_size,
+          block_size,
+          num_kv_heads,
+          softmax_scale,
+          max_num_blocks_per_seq,
+          q_stride,
+          kv_block_stride,
+          kv_head_stride,
+          num_heads,
+          num_seqs,
+          softcapping),
       inputs);
 }
 
@@ -15,11 +15,29 @@ class PagedAttention : public UnaryPrimitive {
       Stream stream,
       bool use_v1,
       int max_context_len,
+      int head_size,
+      int block_size,
+      int num_kv_heads,
+      int max_num_blocks_per_seq,
+      int q_stride,
+      int kv_block_stride,
+      int kv_head_stride,
+      int num_heads,
+      int num_seqs,
       float softmax_scale,
       std::optional<float> softcapping = std::nullopt)
       : UnaryPrimitive(stream),
         use_v1_(use_v1),
         max_context_len_(max_context_len),
+        head_size_(head_size),
+        block_size_(block_size),
+        num_kv_heads_(num_kv_heads),
+        max_num_blocks_per_seq_(max_num_blocks_per_seq),
+        q_stride_(q_stride),
+        kv_block_stride_(kv_block_stride),
+        kv_head_stride_(kv_head_stride),
+        num_heads_(num_heads),
+        num_seqs_(num_seqs),
         softmax_scale_(softmax_scale),
         softcapping_(softcapping) {}
 
@@ -34,12 +52,26 @@ class PagedAttention : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   auto state() const {
-    return std::make_tuple(max_context_len_, softmax_scale_, softcapping_);
+    return std::make_tuple(
+        max_context_len_,
+        head_size_,
+        block_size_,
+        softmax_scale_,
+        softcapping_);
   }
 
  private:
   bool use_v1_;
   int max_context_len_;
+  int head_size_;
+  int block_size_;
+  int num_kv_heads_;
+  int max_num_blocks_per_seq_;
+  int q_stride_;
+  int kv_block_stride_;
+  int kv_head_stride_;
+  int num_heads_;
+  int num_seqs_;
   float softmax_scale_;
   std::optional<float> softcapping_ = std::nullopt;
 };