MLX
 
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_simdgroup>
4
5using namespace metal;
6
// Function constant 20: when true, the attention kernels in this file receive
// a boolean mask buffer plus its strides and skip every key whose mask entry
// is false. When false, the mask arguments are not part of the kernel.
7constant bool has_mask [[function_constant(20)]];
// Function constant 21: when true, the query offset reuses the output offset
// (tpg.x * q_seq_idx + head_idx); otherwise it is
// head_idx * tpg.y + q_seq_idx.
8constant bool query_transposed [[function_constant(21)]];
9
// Single-pass scaled dot-product attention for one query vector.
//
// Each threadgroup handles the (head, query position) pair (tid.x, tid.y).
// Simdgroup `simd_gid` visits keys simd_gid, simd_gid + BN, ... < N. Inside a
// simdgroup every lane owns D / BD consecutive query/key elements and V / BD
// consecutive value elements. Each simdgroup keeps a running max, a running
// sum of exponentials and a rescaled output accumulator (streaming softmax);
// the per-simdgroup partials are merged through threadgroup memory at the
// end.
//
// Template parameters:
//   T  element type of queries / keys / values / out
//   D  query/key head dimension
//   V  value head dimension (defaults to D)
//
// Arguments:
//   queries, keys, values   input buffers
//   out                     output buffer, V elements per (query, head)
//   gqa_factor              head_idx / gqa_factor selects the kv head
//   N                       number of keys to visit
//   k_stride, v_stride      element offset between consecutive kv heads
//   scale                   multiplied into the query when it is loaded
//   mask + mask_*_stride    only present when `has_mask` is set; a false
//                           entry skips the corresponding key
//
// NOTE(review): the final merge indexes threadgroup arrays with simd_lid and
// simd_gid up to BN / BD, so it presumably expects a launch with BN
// simdgroups of BD threads - confirm against the host dispatch.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  constexpr int inner_k_stride = BN * D;
  constexpr int inner_v_stride = BN * V;

  // All accumulation happens in float regardless of T.
  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
  values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }

  // After the final transpose, simdgroup g holds output elements
  // [g * v_per_thread, (g + 1) * v_per_thread).
  out += o_offset * V + simd_gid * v_per_thread;

  // Read the (pre-scaled) query and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  // Use a large finite sentinel rather than -INFINITY: when a simdgroup (or
  // the whole threadgroup) never sees an unmasked key, the merge below would
  // otherwise evaluate exp(-inf - (-inf)) = NaN. This matches the sentinel
  // used by sdpa_vector_2pass_1.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // For each key assigned to this simdgroup
  for (int i = simd_gid; i < N; i += BN) {
    if (!has_mask || mask[0]) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score: per-lane partial dot product, then a
      // simdgroup-wide sum
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Update the accumulators: rescale the old state to the new max
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (has_mask) {
      mask += BN * mask_kv_seq_stride;
    }
  }

  // Each simdgroup has a partial result, so we need to combine them.

  // First communicate the max and sum_exp: lane l of every simdgroup picks
  // up the statistics of simdgroup l.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now aggregate all the outputs. Writing at [lane][simdgroup] and reading
  // at [simdgroup][lane] transposes the tile, so lane l of simdgroup g reads
  // simdgroup l's partial for output element g * v_per_thread + i; `factor`
  // on lane l is the rescale for simdgroup l.
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const U merged = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    // A zero sum means no key contributed (N == 0 or everything masked);
    // write 0 instead of 0 / 0 = NaN.
    o[i] = (sum_exp_score == 0) ? merged : (merged / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
139
// First pass of the two-pass attention kernel for one query vector.
//
// The key sequence is split into `blocks` interleaved chunks selected by
// tid.z: block b, simdgroup g visits keys b * BN + g, then advances by
// blocks * BN. Each threadgroup produces, for its block, an un-normalised
// float partial output of V elements plus the block's running max and sum of
// exponentials, so that sdpa_vector_2pass_2 can merge the blocks.
//
// Template parameters:
//   T  element type of queries / keys / values
//   D  query/key head dimension
//   V  value head dimension (defaults to D)
//
// Arguments:
//   queries, keys, values   input buffers
//   out                     float partial outputs, laid out as
//                           [o_offset][block][V]
//   sums, maxs              per-(o_offset, block) sum of exponentials / max
//   gqa_factor              head_idx / gqa_factor selects the kv head
//   N                       number of keys to visit
//   k_stride, v_stride      element offset between consecutive kv heads
//   scale                   multiplied into the query when it is loaded
//   mask + mask_*_stride    only present when `has_mask` is set; a false
//                           entry skips the corresponding key
//
// NOTE(review): the indexing presumably expects BN simdgroups of BD threads
// per threadgroup and a grid z-extent equal to `blocks` - confirm against the
// host dispatch.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  constexpr int inner_k_stride = BN * D;
  constexpr int inner_v_stride = BN * V;
  constexpr int blocks = 32;

  // All accumulation happens in float regardless of T.
  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  const int kv_head_idx = head_idx / gqa_factor;

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
      simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the (pre-scaled) query and zero the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  // Finite sentinel so that blocks/simdgroups which see no key still yield
  // well-defined rescale factors below.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // For each key assigned to this (block, simdgroup)
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    if (!has_mask || mask[0]) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score: per-lane partial dot product, then a
      // simdgroup-wide sum
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Update the accumulators: rescale the old state to the new max
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (has_mask) {
      mask += BN * blocks * mask_kv_seq_stride;
    }
  }

  // Each simdgroup has a partial result, so we need to combine them.

  // First communicate the max and sum_exp. Only the first BN lanes map to a
  // simdgroup; the remaining lanes contribute neutral values.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the block's sum and max. Both values are uniform across the
  // simdgroup after the reductions above, so a single lane suffices.
  if (simd_gid == 0 && simd_lid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now aggregate the outputs of the BN simdgroups
  for (int i = 0; i < v_per_thread; i++) {
    // Rescale this simdgroup's partial to the block-wide max before sharing.
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroup 0 sums the BN partials for each lane and writes the result
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      // `out` is a float buffer: store the accumulator as-is instead of
      // rounding it through T first, which would discard precision before
      // the second pass reduces the partials.
      out[i] = output;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
287
// Second pass of the two-pass attention kernel: merges per-block partials
// into the final output for one (head, query position) pair.
//
// Simdgroup g loads the partial output of block g (D floats, D / BD per
// lane); lane l of every simdgroup loads the max and sum of block l. The
// partials are rescaled to the global max, summed across blocks through a
// transposed threadgroup tile, normalised by the merged sum of exponentials
// and written as T.
//
// Template parameters:
//   T  element type of the final output
//   D  number of output elements per (query, head)
//
// Arguments:
//   partials   float partial outputs, laid out as [q_offset][block][D]
//   sums/maxs  per-(q_offset, block) sum of exponentials / max
//   out        final output, D elements per (query, head)
//
// NOTE(review): indexing presumably expects BN (== blocks) simdgroups of BD
// threads per threadgroup - confirm against the host dispatch.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int n_heads = tpg.x;
  const int q_offset = n_heads * q_seq_idx + head_idx;
  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += q_offset * blocks;
  maxs += q_offset * blocks;
  out += q_offset * D + simd_gid * elem_per_thread;

  // First everybody reads the max and sum_exp: lane l holds block l's
  // statistics, so `factor` on lane l rescales block l to the global max.
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Now read the block into registers and then use shared memory to
  // transpose it
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    // Written at [lane][simdgroup], read at [simdgroup][lane]: lane l of
    // simdgroup g picks up block l's partial for element
    // g * elem_per_thread + i.
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const U merged = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    // A zero sum means no key contributed in any block; write 0 instead of
    // 0 / 0 = NaN.
    o[i] = (sum_exp_score == 0) ? merged : (merged / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:240
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
constant bool query_transposed
Definition sdpa_vector.h:8
constant bool has_mask
Definition sdpa_vector.h:7
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:11
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:289
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:141