9 const device T* queries [[buffer(0)]],
10 const device T* keys [[buffer(1)]],
11 const device T* values [[buffer(2)]],
12 device T* out [[buffer(3)]],
13 const constant
int& gqa_factor,
14 const constant
int& N,
15 const constant
size_t& k_stride,
16 const constant
float& scale,
17 uint3 tid [[threadgroup_position_in_grid]],
18 uint simd_gid [[simdgroup_index_in_threadgroup]],
19 uint simd_lid [[thread_index_in_simdgroup]]) {
20 constexpr int BN = 32;
21 constexpr int BD = 32;
22 constexpr int elem_per_thread = D / BD;
24 const int stride = BN * D;
28 thread U q[elem_per_thread];
29 thread U k[elem_per_thread];
30 thread U o[elem_per_thread];
32 threadgroup U outputs[BN * BD];
33 threadgroup U max_scores[BN];
34 threadgroup U sum_exp_scores[BN];
37 const int head_idx = tid.y;
38 const int kv_head_idx = head_idx / gqa_factor;
39 queries += head_idx * D + simd_lid * elem_per_thread;
40 keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
41 values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
42 out += head_idx * D + simd_gid * elem_per_thread;
45 for (
int i = 0; i < elem_per_thread; i++) {
46 q[i] =
static_cast<U
>(scale) * queries[i];
48 for (
int i = 0; i < elem_per_thread; i++) {
52 U max_score = -INFINITY;
56 for (
int i = simd_gid; i < N; i += BN) {
58 for (
int i = 0; i < elem_per_thread; i++) {
64 for (
int i = 0; i < elem_per_thread; i++) {
70 U new_max =
max(max_score, score);
71 U factor =
fast::exp(max_score - new_max);
75 sum_exp_score = sum_exp_score * factor + exp_score;
78 for (
int i = 0; i < elem_per_thread; i++) {
79 o[i] = o[i] * factor + exp_score * values[i];
86 threadgroup_barrier(mem_flags::mem_threadgroup);
92 max_scores[simd_gid] = max_score;
93 sum_exp_scores[simd_gid] = sum_exp_score;
95 threadgroup_barrier(mem_flags::mem_threadgroup);
96 max_score = max_scores[simd_lid];
98 U factor =
fast::exp(max_score - new_max);
99 sum_exp_score =
simd_sum(sum_exp_scores[simd_lid] * factor);
102 for (
int i = 0; i < elem_per_thread; i++) {
103 outputs[simd_lid * BD + simd_gid] = o[i];
104 threadgroup_barrier(mem_flags::mem_threadgroup);
105 o[i] =
simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
106 threadgroup_barrier(mem_flags::mem_threadgroup);
111 for (
int i = 0; i < elem_per_thread; i++) {
112 out[i] =
static_cast<T
>(o[i]);
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:8