mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
SDPA support for small batch (over sequence) queries (#1922)
* batch query sdpa * batch sdpa for query
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
using namespace metal;
|
||||
|
||||
constant bool has_mask [[function_constant(20)]];
|
||||
constant bool query_transposed [[function_constant(21)]];
|
||||
|
||||
template <typename T, int D, int V = D>
|
||||
[[kernel]] void sdpa_vector(
|
||||
@@ -18,9 +19,11 @@ template <typename T, int D, int V = D>
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 32;
|
||||
@@ -41,15 +44,21 @@ template <typename T, int D, int V = D>
|
||||
threadgroup U sum_exp_scores[BN];
|
||||
|
||||
// Adjust positions
|
||||
const int head_idx = tid.y;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + simd_lid * qk_per_thread;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
|
||||
mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
out += head_idx * V + simd_gid * v_per_thread;
|
||||
|
||||
out += o_offset * V + simd_gid * v_per_thread;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
for (int i = 0; i < qk_per_thread; i++) {
|
||||
@@ -95,7 +104,7 @@ template <typename T, int D, int V = D>
|
||||
keys += inner_k_stride;
|
||||
values += inner_v_stride;
|
||||
if (has_mask) {
|
||||
mask += BN * mask_seq_stride;
|
||||
mask += BN * mask_kv_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,9 +151,11 @@ template <typename T, int D, int V = D>
|
||||
const constant size_t& v_stride,
|
||||
const constant float& scale,
|
||||
const device bool* mask [[function_constant(has_mask)]],
|
||||
const constant int& mask_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
|
||||
const constant int& mask_head_stride [[function_constant(has_mask)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 8;
|
||||
@@ -167,20 +178,26 @@ template <typename T, int D, int V = D>
|
||||
|
||||
// Adjust positions
|
||||
const int block_idx = tid.z;
|
||||
const int head_idx = tid.y;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int o_offset = tpg.x * q_seq_idx + head_idx;
|
||||
const int q_offset =
|
||||
query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
|
||||
const int kv_head_idx = head_idx / gqa_factor;
|
||||
queries += head_idx * D + simd_lid * qk_per_thread;
|
||||
|
||||
queries += q_offset * D + simd_lid * qk_per_thread;
|
||||
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
|
||||
simd_lid * qk_per_thread;
|
||||
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
|
||||
simd_lid * v_per_thread;
|
||||
out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
|
||||
if (has_mask) {
|
||||
mask += head_idx * mask_head_stride +
|
||||
(block_idx * BN + simd_gid) * mask_seq_stride;
|
||||
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
|
||||
q_seq_idx * mask_q_seq_stride;
|
||||
}
|
||||
sums += head_idx * blocks + block_idx;
|
||||
maxs += head_idx * blocks + block_idx;
|
||||
sums += o_offset * blocks + block_idx;
|
||||
maxs += o_offset * blocks + block_idx;
|
||||
|
||||
// Read the query and 0 the output accumulator
|
||||
for (int i = 0; i < qk_per_thread; i++) {
|
||||
@@ -226,7 +243,7 @@ template <typename T, int D, int V = D>
|
||||
keys += blocks * inner_k_stride;
|
||||
values += blocks * inner_v_stride;
|
||||
if (has_mask) {
|
||||
mask += BN * blocks * mask_seq_stride;
|
||||
mask += BN * blocks * mask_kv_seq_stride;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,6 +292,7 @@ template <typename T, int D>
|
||||
const device float* maxs [[buffer(2)]],
|
||||
device T* out [[buffer(3)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 tpg [[threadgroups_per_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int BN = 32;
|
||||
@@ -288,11 +306,14 @@ template <typename T, int D>
|
||||
threadgroup U outputs[BN * BD];
|
||||
|
||||
// Adjust positions
|
||||
const int head_idx = tid.y;
|
||||
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += head_idx * blocks;
|
||||
maxs += head_idx * blocks;
|
||||
out += head_idx * D + simd_gid * elem_per_thread;
|
||||
const int head_idx = tid.x;
|
||||
const int q_seq_idx = tid.y;
|
||||
const int n_heads = tpg.x;
|
||||
const int q_offset = n_heads * q_seq_idx + head_idx;
|
||||
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
|
||||
sums += q_offset * blocks;
|
||||
maxs += q_offset * blocks;
|
||||
out += q_offset * D + simd_gid * elem_per_thread;
|
||||
|
||||
// First everybody reads the max and sum_exp
|
||||
U max_score = maxs[simd_lid];
|
||||
|
||||
Reference in New Issue
Block a user