MLX
 
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_simdgroup>
4
5using namespace metal;
6
// Function constant 20: when true, the attention kernels take a boolean mask
// together with its strides and skip every key whose mask entry is false.
7constant bool has_mask [[function_constant(20)]];
// Function constant 21: when true, the query offset is the same as the output
// offset (tpg.x * q_seq_idx + head_idx); otherwise it is
// head_idx * tpg.y + q_seq_idx.
8constant bool query_transposed [[function_constant(21)]];
9
// Single-pass scaled dot-product attention for one query vector per
// threadgroup.
//
// tid.x selects the query head and tid.y the query position. Simdgroup `g`
// visits keys g, g + BN, g + 2 * BN, ... (< N) and keeps a running softmax
// (max, sum of exponentials, weighted value sum) in registers. The per
// simdgroup results are then merged through threadgroup memory and written to
// `out` as type T.
//
// Template parameters:
//   T - storage type of queries / keys / values / out.
//   D - query/key head dimension; each lane owns D / 32 consecutive elements.
//   V - value head dimension (defaults to D); each lane owns V / 32 elements.
//
// Parameters:
//   gqa_factor      - head_idx / gqa_factor gives the key/value head.
//   N               - number of keys to attend over.
//   k_*/v_* strides - element strides of keys / values per head and per key.
//   scale           - multiplied into the query when it is loaded.
//   mask + strides  - only present when the `has_mask` function constant is
//                     set; keys whose mask entry is false are skipped.
//
// NOTE(review): the merge step reads BN (= 32) entries of the shared max/sum
// arrays, so this presumably expects a threadgroup of 32 simdgroups of 32
// lanes - confirm against the host-side dispatch.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_head_stride,
    const constant size_t& k_seq_stride,
    const constant size_t& v_head_stride,
    const constant size_t& v_seq_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using U = float;

  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;

  // Element distance between two consecutive keys / values visited by the
  // same simdgroup.
  const int k_step = BN * int(k_seq_stride);
  const int v_step = BN * int(v_seq_stride);

  thread U q_reg[qk_per_thread];
  thread U acc[v_per_thread];

  threadgroup U shared_out[BN * BD];
  threadgroup U shared_max[BN];
  threadgroup U shared_sum[BN];

  // Locate this threadgroup's query / output row and this lane's slice.
  const int head = tid.x;
  const int q_pos = tid.y;
  const int kv_head = head / gqa_factor;
  const int out_row = tpg.x * q_pos + head;
  int q_row;
  if (query_transposed) {
    q_row = out_row;
  } else {
    q_row = head * tpg.y + q_pos;
  }

  queries += q_row * D + simd_lid * qk_per_thread;
  keys += kv_head * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (has_mask) {
    mask += head * mask_head_stride + simd_gid * mask_kv_seq_stride +
        q_pos * mask_q_seq_stride;
  }
  out += out_row * V + simd_gid * v_per_thread;

  // Load the pre-scaled query slice and clear the output accumulator.
  for (int d = 0; d < qk_per_thread; ++d) {
    q_reg[d] = static_cast<U>(scale) * queries[d];
  }
  for (int d = 0; d < v_per_thread; ++d) {
    acc[d] = 0;
  }

  U running_max = -INFINITY;
  U running_sum = 0;

  // Walk over the keys assigned to this simdgroup.
  for (int kv = simd_gid; kv < N; kv += BN) {
    bool attend = true;
    if (has_mask) {
      attend = mask[0];
    }

    if (attend) {
      // Partial dot product over this lane's slice, then reduce across lanes.
      U score = 0;
      for (int d = 0; d < qk_per_thread; ++d) {
        const U key_elem = keys[d];
        score += q_reg[d] * key_elem;
      }
      score = simd_sum(score);

      // Online softmax update: rescale what was accumulated so far to the new
      // running maximum before adding this key's contribution.
      const U next_max = max(running_max, score);
      const U rescale = fast::exp(running_max - next_max);
      const U weight = fast::exp(score - next_max);

      running_max = next_max;
      running_sum = running_sum * rescale + weight;

      for (int d = 0; d < v_per_thread; ++d) {
        acc[d] = acc[d] * rescale + weight * values[d];
      }
    }

    // Advance to the next key/value (and mask entry) for this simdgroup.
    keys += k_step;
    values += v_step;
    if (has_mask) {
      mask += BN * mask_kv_seq_stride;
    }
  }

  // Publish each simdgroup's max and sum so that lane `l` of every simdgroup
  // can pick up the statistics of simdgroup `l`.
  if (simd_lid == 0) {
    shared_max[simd_gid] = running_max;
    shared_sum[simd_gid] = running_sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  const U group_max = shared_max[simd_lid];
  const U global_max = simd_max(group_max);
  const U group_scale = fast::exp(group_max - global_max);
  const U denom = simd_sum(shared_sum[simd_lid] * group_scale);

  // Transpose the partial outputs through threadgroup memory: after the swap,
  // lane `l` holds simdgroup `l`'s partial, so a simd_sum merges all
  // simdgroups for one output element.
  for (int d = 0; d < v_per_thread; ++d) {
    shared_out[simd_lid * BD + simd_gid] = acc[d];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    acc[d] =
        simd_sum(shared_out[simd_gid * BD + simd_lid] * group_scale) / denom;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Only the first lane of each simdgroup stores its output elements.
  if (simd_lid != 0) {
    return;
  }
  for (int d = 0; d < v_per_thread; ++d) {
    out[d] = static_cast<T>(acc[d]);
  }
}
143
// First pass of the two-pass single-query attention.
//
// The key sequence is split over `blocks` (= 32) threadgroups along tid.z;
// tid.x selects the query head and tid.y the query position. Inside a
// threadgroup, simdgroup `g` visits keys block_idx * BN + g and then every
// (blocks * BN)-th key after it, keeping a running softmax in registers.
//
// Per (query row, block) the kernel writes:
//   out  - V float partial outputs, rescaled to the block maximum but NOT yet
//          divided by the softmax denominator,
//   sums - the block's sum of exponentials (relative to the block maximum),
//   maxs - the block maximum score.
//
// Template parameters:
//   T - storage type of queries / keys / values.
//   D - query/key head dimension; each lane owns D / 32 consecutive elements.
//   V - value head dimension (defaults to D); each lane owns V / 32 elements.
//
// Parameters:
//   gqa_factor      - head_idx / gqa_factor gives the key/value head.
//   N               - number of keys to attend over.
//   k_*/v_* strides - element strides of keys / values per head and per key.
//   scale           - multiplied into the query when it is loaded.
//   mask + strides  - only present when the `has_mask` function constant is
//                     set; keys whose mask entry is false are skipped.
//
// NOTE(review): the shared max/sum arrays hold BN (= 8) entries indexed by
// simd_gid, so this presumably expects 8 simdgroups of 32 lanes per
// threadgroup - confirm against the host-side dispatch.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_head_stride,
    const constant size_t& k_seq_stride,
    const constant size_t& v_head_stride,
    const constant size_t& v_seq_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_kv_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_q_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);
  constexpr int blocks = 32;

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = tpg.x * q_seq_idx + head_idx;
  const int q_offset =
      query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx;
  const int kv_head_idx = head_idx / gqa_factor;

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride +
      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride +
      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query and 0 the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  // A large finite sentinel: when a simdgroup sees no key, the later
  // exp(max_score - new_max) stays finite (exp(0) == 1 if every simdgroup is
  // empty) and is multiplied by a zero sum / zero accumulator.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // For each key
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    if (!has_mask || mask[0]) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Update the accumulators (online softmax: rescale the running sum and
      // output to the new maximum before adding this key)
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (has_mask) {
      mask += BN * blocks * mask_kv_seq_stride;
    }
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Only the first BN lanes map to a simdgroup; the others contribute the
  // neutral elements of the max / sum reductions.
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Rescale factor that brings this simdgroup's accumulator to the block
  // maximum; it does not depend on the output element, so compute it once.
  const U group_factor = fast::exp(max_scores[simd_gid] - new_max);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] = o[i] * group_factor;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      // `out` is a float buffer: store the float partial as-is. Narrowing it
      // through T first would only discard precision before the second pass
      // combines the blocks.
      out[i] = output;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
293
// Second pass of the two-pass single-query attention: merges the `blocks`
// (= 32) per-block partial results of one query row into the final output.
//
// tid.x selects the head and tid.y the query position. Simdgroup `g` loads
// the partial output of block `g` (each lane owns D / 32 consecutive
// elements), while lane `l` of every simdgroup loads the max and sum of block
// `l`. The partials are rescaled to the global maximum, summed over blocks,
// divided by the combined softmax denominator and written to `out` as T.
//
// Template parameters:
//   T - element type of `out`.
//   D - length of one partial / output row.
//
// Parameters:
//   partials - float partial outputs, `blocks` rows of D per query row.
//   sums     - per-block sums of exponentials, `blocks` per query row.
//   maxs     - per-block maximum scores, `blocks` per query row.
//   out      - final output, D elements per query row.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using U = float;

  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  thread U merged[elem_per_thread];
  threadgroup U shared_out[BN * BD];

  // Locate the query row handled by this threadgroup.
  const int head = tid.x;
  const int q_pos = tid.y;
  const int head_count = tpg.x;
  const int row = head_count * q_pos + head;

  partials += row * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += row * blocks;
  maxs += row * blocks;
  out += row * D + simd_gid * elem_per_thread;

  // Lane `l` picks up block `l`'s statistics; reduce them across the lanes.
  const U block_max = maxs[simd_lid];
  const U global_max = simd_max(block_max);
  const U block_scale = fast::exp(block_max - global_max);
  const U denom = simd_sum(sums[simd_lid] * block_scale);

  // Transpose each partial element through threadgroup memory so that lane
  // `l` ends up holding block `l`'s value, then merge the blocks with a
  // simd_sum.
  for (int e = 0; e < elem_per_thread; ++e) {
    shared_out[simd_lid * BD + simd_gid] = partials[e];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    merged[e] =
        simd_sum(shared_out[simd_gid * BD + simd_lid] * block_scale) / denom;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Only the first lane of each simdgroup stores its output elements.
  if (simd_lid != 0) {
    return;
  }
  for (int e = 0; e < elem_per_thread; ++e) {
    out[e] = static_cast<T>(merged[e]);
  }
}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:240
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
constant bool query_transposed
Definition sdpa_vector.h:8
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:145
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_head_stride, const constant size_t &k_seq_stride, const constant size_t &v_head_stride, const constant size_t &v_seq_stride, const constant float &scale, const device bool *mask, const constant int &mask_kv_seq_stride, const constant int &mask_q_seq_stride, const constant int &mask_head_stride, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:11
constant bool has_mask
Definition sdpa_vector.h:7
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint3 tpg, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:295