11 const device T* queries [[buffer(0)]],
12 const device T* keys [[buffer(1)]],
13 const device T* values [[buffer(2)]],
14 device T* out [[buffer(3)]],
15 const constant
int& gqa_factor,
16 const constant
int& N,
17 const constant
size_t& k_stride,
18 const constant
size_t& v_stride,
19 const constant
float& scale,
20 const device
bool* mask [[function_constant(
has_mask)]],
21 const constant
int& mask_seq_stride [[function_constant(
has_mask)]],
22 const constant
int& mask_head_stride [[function_constant(
has_mask)]],
23 uint3 tid [[threadgroup_position_in_grid]],
24 uint simd_gid [[simdgroup_index_in_threadgroup]],
25 uint simd_lid [[thread_index_in_simdgroup]]) {
26 constexpr int BN = 32;
27 constexpr int BD = 32;
28 constexpr int elem_per_thread = D / BD;
29 constexpr int stride = BN * D;
33 thread U q[elem_per_thread];
34 thread U k[elem_per_thread];
35 thread U o[elem_per_thread];
37 threadgroup U outputs[BN * BD];
38 threadgroup U max_scores[BN];
39 threadgroup U sum_exp_scores[BN];
42 const int head_idx = tid.y;
43 const int kv_head_idx = head_idx / gqa_factor;
44 queries += head_idx * D + simd_lid * elem_per_thread;
45 keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
46 values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
48 mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
50 out += head_idx * D + simd_gid * elem_per_thread;
53 for (
int i = 0; i < elem_per_thread; i++) {
54 q[i] =
static_cast<U
>(scale) * queries[i];
56 for (
int i = 0; i < elem_per_thread; i++) {
60 U max_score = -INFINITY;
64 for (
int i = simd_gid; i < N; i += BN) {
67 for (
int j = 0; j < elem_per_thread; j++) {
73 for (
int j = 0; j < elem_per_thread; j++) {
79 U new_max =
max(max_score, score);
80 U factor =
fast::exp(max_score - new_max);
84 sum_exp_score = sum_exp_score * factor + exp_score;
87 for (
int j = 0; j < elem_per_thread; j++) {
88 o[j] = o[j] * factor + exp_score * values[j];
96 mask += BN * mask_seq_stride;
104 max_scores[simd_gid] = max_score;
105 sum_exp_scores[simd_gid] = sum_exp_score;
107 threadgroup_barrier(mem_flags::mem_threadgroup);
108 max_score = max_scores[simd_lid];
110 U factor =
fast::exp(max_score - new_max);
111 sum_exp_score =
simd_sum(sum_exp_scores[simd_lid] * factor);
114 for (
int i = 0; i < elem_per_thread; i++) {
115 outputs[simd_lid * BD + simd_gid] = o[i];
116 threadgroup_barrier(mem_flags::mem_threadgroup);
117 o[i] =
simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
118 threadgroup_barrier(mem_flags::mem_threadgroup);
123 for (
int i = 0; i < elem_per_thread; i++) {
124 out[i] =
static_cast<T
>(o[i]);
131 const device T* queries [[buffer(0)]],
132 const device T* keys [[buffer(1)]],
133 const device T* values [[buffer(2)]],
134 device
float* out [[buffer(3)]],
135 device
float* sums [[buffer(4)]],
136 device
float* maxs [[buffer(5)]],
137 const constant
int& gqa_factor,
138 const constant
int& N,
139 const constant
size_t& k_stride,
140 const constant
size_t& v_stride,
141 const constant
float& scale,
142 const device
bool* mask [[function_constant(
has_mask)]],
143 const constant
int& mask_seq_stride [[function_constant(
has_mask)]],
144 const constant
int& mask_head_stride [[function_constant(
has_mask)]],
145 uint3 tid [[threadgroup_position_in_grid]],
146 uint simd_gid [[simdgroup_index_in_threadgroup]],
147 uint simd_lid [[thread_index_in_simdgroup]]) {
148 constexpr int BN = 8;
149 constexpr int BD = 32;
150 constexpr int elem_per_thread = D / BD;
151 constexpr int stride = BN * D;
152 constexpr int blocks = 32;
156 thread U q[elem_per_thread];
157 thread U k[elem_per_thread];
158 thread U o[elem_per_thread];
160 threadgroup U outputs[BN * BD];
161 threadgroup U max_scores[BN];
162 threadgroup U sum_exp_scores[BN];
165 const int block_idx = tid.z;
166 const int head_idx = tid.y;
167 const int kv_head_idx = head_idx / gqa_factor;
168 queries += head_idx * D + simd_lid * elem_per_thread;
169 keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
170 simd_lid * elem_per_thread;
171 values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
172 simd_lid * elem_per_thread;
173 out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
175 mask += head_idx * mask_head_stride +
176 (block_idx * BN + simd_gid) * mask_seq_stride;
178 sums += head_idx * blocks + block_idx;
179 maxs += head_idx * blocks + block_idx;
182 for (
int i = 0; i < elem_per_thread; i++) {
183 q[i] =
static_cast<U
>(scale) * queries[i];
185 for (
int i = 0; i < elem_per_thread; i++) {
193 for (
int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
196 for (
int i = 0; i < elem_per_thread; i++) {
202 for (
int i = 0; i < elem_per_thread; i++) {
203 score += q[i] * k[i];
208 U new_max =
max(max_score, score);
209 U factor =
fast::exp(max_score - new_max);
210 U exp_score =
fast::exp(score - new_max);
213 sum_exp_score = sum_exp_score * factor + exp_score;
216 for (
int i = 0; i < elem_per_thread; i++) {
217 o[i] = o[i] * factor + exp_score * values[i];
222 keys += blocks * stride;
223 values += blocks * stride;
225 mask += BN * blocks * mask_seq_stride;
233 max_scores[simd_gid] = max_score;
234 sum_exp_scores[simd_gid] = sum_exp_score;
236 threadgroup_barrier(mem_flags::mem_threadgroup);
237 max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
239 U factor =
fast::exp(max_score - new_max);
240 sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
241 sum_exp_score =
simd_sum(sum_exp_score * factor);
245 sums[0] = sum_exp_score;
250 for (
int i = 0; i < elem_per_thread; i++) {
251 outputs[simd_lid * BN + simd_gid] =
252 o[i] *
fast::exp(max_scores[simd_gid] - new_max);
253 threadgroup_barrier(mem_flags::mem_threadgroup);
257 U output = outputs[simd_lid * BN];
258 for (
int j = 1; j < BN; j++) {
259 output += outputs[simd_lid * BN + j];
261 out[i] =
static_cast<T
>(output);
263 threadgroup_barrier(mem_flags::mem_threadgroup);
269 const device
float* partials [[buffer(0)]],
270 const device
float* sums [[buffer(1)]],
271 const device
float* maxs [[buffer(2)]],
272 device T* out [[buffer(3)]],
273 uint3 tid [[threadgroup_position_in_grid]],
274 uint simd_gid [[simdgroup_index_in_threadgroup]],
275 uint simd_lid [[thread_index_in_simdgroup]]) {
276 constexpr int BN = 32;
277 constexpr int BD = 32;
278 constexpr int elem_per_thread = D / BD;
279 constexpr int blocks = 32;
283 thread U o[elem_per_thread];
284 threadgroup U outputs[BN * BD];
287 const int head_idx = tid.y;
288 partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
289 sums += head_idx * blocks;
290 maxs += head_idx * blocks;
291 out += head_idx * D + simd_gid * elem_per_thread;
294 U max_score = maxs[simd_lid];
296 U factor =
fast::exp(max_score - new_max);
297 U sum_exp_score =
simd_sum(sums[simd_lid] * factor);
301 for (
int i = 0; i < elem_per_thread; i++) {
304 for (
int i = 0; i < elem_per_thread; i++) {
305 outputs[simd_lid * BD + simd_gid] = o[i];
306 threadgroup_barrier(mem_flags::mem_threadgroup);
307 o[i] =
simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
308 threadgroup_barrier(mem_flags::mem_threadgroup);
313 for (
int i = 0; i < elem_per_thread; i++) {
314 out[i] =
static_cast<T
>(o[i]);
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:268
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:130
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:10