SDPA support for small batch (over sequence) queries (#1922)

* batch query sdpa

* batch sdpa for query
This commit is contained in:
Awni Hannun
2025-03-04 10:59:04 -08:00
committed by GitHub
parent 6bcd6bcf70
commit e613d0eaf0
5 changed files with 159 additions and 45 deletions

View File

@@ -5,6 +5,7 @@
using namespace metal;
constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
@@ -18,9 +19,11 @@ template <typename T, int D, int V = D>
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
@@ -41,15 +44,21 @@ template <typename T, int D, int V = D>
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.y;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * qk_per_thread;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
out += head_idx * V + simd_gid * v_per_thread;
out += o_offset * V + simd_gid * v_per_thread;
// Read the query and 0 the output accumulator
for (int i = 0; i < qk_per_thread; i++) {
@@ -95,7 +104,7 @@ template <typename T, int D, int V = D>
keys += inner_k_stride;
values += inner_v_stride;
if (has_mask) {
mask += BN * mask_seq_stride;
mask += BN * mask_kv_seq_stride;
}
}
@@ -142,9 +151,11 @@ template <typename T, int D, int V = D>
const constant size_t& v_stride,
const constant float& scale,
const device bool* mask [[function_constant(has_mask)]],
const constant int& mask_seq_stride [[function_constant(has_mask)]],
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
const constant int& mask_head_stride [[function_constant(has_mask)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8;
@@ -167,20 +178,26 @@ template <typename T, int D, int V = D>
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = tpg.x * q_seq_idx + head_idx;
const int q_offset =
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * qk_per_thread;
queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * qk_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
simd_lid * v_per_thread;
out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (has_mask) {
mask += head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_seq_stride;
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
}
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
sums += o_offset * blocks + block_idx;
maxs += o_offset * blocks + block_idx;
// Read the query and 0 the output accumulator
for (int i = 0; i < qk_per_thread; i++) {
@@ -226,7 +243,7 @@ template <typename T, int D, int V = D>
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
if (has_mask) {
mask += BN * blocks * mask_seq_stride;
mask += BN * blocks * mask_kv_seq_stride;
}
}
@@ -275,6 +292,7 @@ template <typename T, int D>
const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
@@ -288,11 +306,14 @@ template <typename T, int D>
threadgroup U outputs[BN * BD];
// Adjust positions
const int head_idx = tid.y;
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += head_idx * blocks;
maxs += head_idx * blocks;
out += head_idx * D + simd_gid * elem_per_thread;
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int n_heads = tpg.x;
const int q_offset = n_heads * q_seq_idx + head_idx;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;
out += q_offset * D + simd_gid * elem_per_thread;
// First everybody reads the max and sum_exp
U max_score = maxs[simd_lid];