MLX
 
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_simdgroup>
4
5using namespace metal;
6
// Function constant 20, set when the pipeline is specialized. When true, the
// kernels in this file receive the `mask`, `mask_seq_stride` and
// `mask_head_stride` arguments (they are tagged
// [[function_constant(has_mask)]]) and skip every key whose boolean mask
// entry is false; when false those arguments are not part of the kernel.
7constant bool has_mask [[function_constant(20)]];
8
/// Single-pass scaled dot-product attention for one query vector per head.
///
/// One threadgroup handles one head (`tid.y`). Simdgroup `g` visits keys
/// g, g + BN, g + 2*BN, ... and keeps a running max, a running sum of
/// exponentials and a partial output (online softmax). Lane `l` owns a
/// D / BD slice of the query/key and a V / BD slice of the value/output.
/// The BN partial results are then merged through threadgroup memory.
///
/// NOTE(review): the merge indexes threadgroup memory with both simd_gid and
/// simd_lid and reuses simd_gid as an output-slice index, so the host is
/// assumed to dispatch exactly BN (= 32) simdgroups of BD (= 32) lanes, with
/// D and V multiples of 32 (D / BD truncates otherwise) — confirm against
/// the launcher.
///
/// @param queries    One D-element query per head, heads contiguous.
/// @param keys       Per KV head (stride k_stride): N keys of D elements.
/// @param values     Per KV head (stride v_stride): N values of V elements.
/// @param out        One V-element output per head.
/// @param gqa_factor Number of query heads that share one KV head.
/// @param N          Number of keys / values.
/// @param scale      Factor applied to the query before the dot products.
/// @param mask       (has_mask only) key i of head h is used iff
///                   mask[h * mask_head_stride + i * mask_seq_stride] is true.
///
/// If no key is used (N == 0 or every key masked) the output is 0 rather
/// than NaN.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  constexpr int inner_k_stride = BN * D;
  constexpr int inner_v_stride = BN * V;

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions: this lane's slice of the query, of this simdgroup's
  // first key/value, and (lane 0 later) of this simdgroup's output slice.
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * qk_per_thread;
  values += kv_head_idx * v_stride + simd_gid * V + simd_lid * v_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
  }
  out += head_idx * V + simd_gid * v_per_thread;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int j = 0; j < qk_per_thread; j++) {
    q[j] = static_cast<U>(scale) * queries[j];
  }
  for (int j = 0; j < v_per_thread; j++) {
    o[j] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // Each simdgroup walks every BN-th key starting at its own index.
  for (int i = simd_gid; i < N; i += BN) {
    if (!has_mask || mask[0]) {
      // Read the key.
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Full dot product: per-lane partial, then a simdgroup reduction.
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Online softmax: rescale the running state to the new maximum.
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to this simdgroup's next key/value.
    keys += inner_k_stride;
    values += inner_v_stride;
    if (has_mask) {
      mask += BN * mask_seq_stride;
    }
  }

  // Each simdgroup holds a partial result; combine them. First share every
  // group's max and sum so that lane l holds group l's statistics.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  // When a group's max already equals the global max its factor is exactly
  // exp(0) = 1. Special-casing it also avoids (-inf) - (-inf) = NaN when no
  // group saw any key (all maxima still -INFINITY); such groups carry a zero
  // sum and zero output, so factor 1 leaves them contributing nothing.
  U factor = (max_score == new_max) ? U(1) : fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Transpose through threadgroup memory so that simdgroup g reduces output
  // slice g across all key groups (one lane per key group).
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    U total = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    // The denominator is 0 only when no key was used at all; emit 0 instead
    // of 0 / 0 = NaN.
    o[i] = (sum_exp_score == 0) ? U(0) : total / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output slice owned by this simdgroup.
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
130
/// First of two attention passes: per-block partial softmax statistics.
///
/// Grid: `tid.y` selects the head and `tid.z` selects one of `blocks` (= 32)
/// key blocks. Within a threadgroup, simdgroup `g` of block `b` visits keys
/// b*BN + g, then steps by blocks*BN, keeping an online-softmax running max,
/// running sum of exponentials and partial output. Lane `l` owns a D / BD
/// slice of query/key and a V / BD slice of value/output.
///
/// For each (head, block) it writes:
///   - out[head][block][0..V): sum over the block's keys of
///     exp(score - block_max) * value, i.e. NOT divided by the sum;
///   - sums[head][block]: sum over the block's keys of exp(score - block_max);
///   - maxs[head][block]: block_max.
/// These are kept in float so that the second pass can merge blocks
/// without precision loss.
///
/// NOTE(review): assumes a dispatch of BN (= 8) simdgroups of BD (= 32)
/// lanes per threadgroup, `blocks` threadgroups along z, and D, V multiples
/// of 32 — confirm against the launcher.
///
/// @param gqa_factor Number of query heads that share one KV head.
/// @param N          Number of keys / values.
/// @param scale      Factor applied to the query before the dot products.
/// @param mask       (has_mask only) key i of head h is used iff
///                   mask[h * mask_head_stride + i * mask_seq_stride] is true.
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  constexpr int inner_k_stride = BN * D;
  constexpr int inner_v_stride = BN * V;
  constexpr int blocks = 32;

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions to this block's first key/value for this simdgroup and
  // to this (head, block)'s slots in the partial outputs.
  const int block_idx = tid.z;
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * V +
      simd_lid * v_per_thread;
  out += head_idx * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_seq_stride;
  }
  sums += head_idx * blocks + block_idx;
  maxs += head_idx * blocks + block_idx;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int j = 0; j < qk_per_thread; j++) {
    q[j] = static_cast<U>(scale) * queries[j];
  }
  for (int j = 0; j < v_per_thread; j++) {
    o[j] = 0;
  }

  // Finite sentinel (rather than -INFINITY) keeps max differences finite.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // For each key assigned to this simdgroup in this block.
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    if (!has_mask || mask[0]) {
      // Read the key.
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Full dot product: per-lane partial, then a simdgroup reduction.
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Online softmax: rescale the running state to the new maximum.
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to this simdgroup's next key/value.
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (has_mask) {
      mask += BN * blocks * mask_seq_stride;
    }
  }

  // Combine the BN simdgroups: share every group's max and sum; lanes
  // >= BN contribute neutral values to the simdgroup reductions.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the block's sum and max (one thread is enough; all agree).
  if (simd_gid == 0 && simd_lid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Factor that rescales this simdgroup's partial output to the block max;
  // max_scores is not written again, so it is loop-invariant.
  const U group_factor = fast::exp(max_scores[simd_gid] - new_max);

  // Sum the BN rescaled partial outputs through threadgroup memory.
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] = o[i] * group_factor;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      // `out` is a float buffer holding an un-normalized intermediate;
      // rounding through T (e.g. half) would lose precision and can
      // overflow, so store the float accumulator directly.
      out[i] = output;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
270
/// Second of two attention passes: merge the per-block partials.
///
/// Grid: `tid.y` selects the head. Simdgroup `g` reads block g's partial
/// output (lane `l` reads a D / BD slice), and lane `l` reads block l's
/// max and sum. The result is
///   out = sum_b partials_b * exp(max_b - M) / sum_b sums_b * exp(max_b - M)
/// with M the maximum over blocks. If every block's sum is 0 (no key used
/// for this head) the output is 0 rather than NaN.
///
/// NOTE(review): assumes exactly `blocks` (= 32) simdgroups of BD (= 32)
/// lanes, matching the first pass's block count, and D a multiple of 32 —
/// confirm against the launcher.
///
/// @param partials [head][block][D] un-normalized outputs from pass 1.
/// @param sums     [head][block] sums of exponentials from pass 1.
/// @param maxs     [head][block] maxima from pass 1.
/// @param out      One D-element output per head.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.y;
  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += head_idx * blocks;
  maxs += head_idx * blocks;
  out += head_idx * D + simd_gid * elem_per_thread;

  // First everybody reads the max and sum_exp (lane l <-> block l).
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Read this block's slice into registers, then transpose through
  // threadgroup memory so simdgroup g reduces output slice g over blocks.
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    U total = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    // A zero denominator means no block used any key; emit 0, not 0 / 0.
    o[i] = (sum_exp_score == 0) ? U(0) : total / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:240
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:272
constant bool has_mask
Definition sdpa_vector.h:7
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:10
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:132