MLX
 
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#include <metal_simdgroup>

using namespace metal;

// Function constant (index 20) chosen when the pipeline is specialized.
// When true, the kernels below take the `mask`, `mask_seq_stride` and
// `mask_head_stride` parameters and skip every key whose mask entry is false.
// When false, those parameters do not exist and every key is attended to.
constant bool has_mask [[function_constant(20)]];
8
// Single-pass scaled dot-product attention for one query vector per head.
//
// Work layout, inferred from the indexing below (NOTE(review): confirm
// against the host-side dispatch):
//   - one threadgroup per query head, selected by tid.y;
//   - BN = 32 simdgroups of BD = 32 lanes. Simdgroup g visits keys
//     g, g + BN, g + 2*BN, ... < N;
//   - each lane owns elem_per_thread = D / BD consecutive elements of the
//     head dimension.
// Each simdgroup keeps an online softmax: a running max score, the running
// sum of exp(score - max), and the exp-weighted sum of value rows. The BN
// partial results are then merged through threadgroup memory.
//
// Parameters (as used below):
//   queries       - read at head_idx * D (D contiguous elements).
//   keys, values  - KV head kv_head_idx = head_idx / gqa_factor starts at
//                   kv_head_idx * k_stride (resp. v_stride). Key rows of D
//                   elements are contiguous.
//   out           - D elements written at head_idx * D.
//   N             - number of keys.
//   scale         - multiplied into the query before the dot products.
//   mask          - only with has_mask. Entry for key k is at
//                   head_idx * mask_head_stride + k * mask_seq_stride;
//                   keys whose entry is false are skipped.
//
// NOTE(review): if every key is masked, new_max stays -INFINITY and
// (-inf) - (-inf) is NaN, so the output for that head is NaN.
template <typename T, int D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int stride = BN * D;
  // elem_per_thread truncates, so any remainder of D would silently be dropped.
  static_assert(D % BD == 0, "sdpa_vector: D must be a multiple of 32");

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions: simdgroup simd_gid starts at key simd_gid, and lane
  // simd_lid at element simd_lid * elem_per_thread of the head dimension.
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride + simd_gid * mask_seq_stride;
  }
  // After the transpose at the end, simdgroup g holds output elements
  // [g * elem_per_thread, (g + 1) * elem_per_thread).
  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int i = 0; i < elem_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // For each key owned by this simdgroup.
  for (int i = simd_gid; i < N; i += BN) {
    if (!has_mask || mask[0]) {
      // Read this lane's slice of the key.
      for (int j = 0; j < elem_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score: partial dot product per lane, reduced
      // across the simdgroup.
      U score = 0;
      for (int j = 0; j < elem_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Online softmax: rescale the old accumulators to the new max.
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator.
      for (int j = 0; j < elem_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to this simdgroup's next key.
    keys += stride;
    values += stride;
    if (has_mask) {
      mask += BN * mask_seq_stride;
    }
  }

  // Each simdgroup holds a partial result; combine them.

  // Share every simdgroup's max and sum_exp. Lane l then reads simdgroup l's
  // values; this relies on BN equal to the 32 lanes of a simdgroup.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  // Rescale factor for simdgroup simd_lid's partial results.
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Transpose the partial outputs through threadgroup memory. Lane l of
  // simdgroup g stores at [l][g]. Lane l of simdgroup g then reads [g][l],
  // i.e. simdgroup l's partial for element g * elem_per_thread + i, so the
  // simd_sum adds up all simdgroups, each weighted by its own factor.
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write the output; every lane holds the same values after simd_sum.
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
128
// First pass of the two-pass single-query attention.
//
// The keys of each head are split across `blocks` (32) threadgroups along
// tid.z (NOTE(review): grid shape assumed from the indexing; confirm against
// the host dispatch). Threadgroup block_idx runs BN = 8 simdgroups.
// Simdgroup g visits keys block_idx * BN + g + k * blocks * BN < N.
//
// For each (head_idx, block_idx) pair it writes:
//   out  - D floats at (head_idx * blocks + block_idx) * D. This is the sum
//          over the block's unmasked keys of exp(score - max) * value and is
//          NOT normalized;
//   sums - at head_idx * blocks + block_idx: the matching sum of
//          exp(score - max);
//   maxs - same index: the max score used as the reference above.
// sdpa_vector_2pass_2 reads this layout with the same `blocks` value.
//
// A simdgroup (or whole block) that sees no unmasked key keeps the finite
// sentinel max -1e9 and sum 0. Presumably -INFINITY is avoided so that an
// entirely empty block gives factor = exp(0) = 1 and sum 0 in the combine
// below, rather than NaN from (-inf) - (-inf) poisoning the merge.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    const device bool* mask [[function_constant(has_mask)]],
    const constant int& mask_seq_stride [[function_constant(has_mask)]],
    const constant int& mask_head_stride [[function_constant(has_mask)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int stride = BN * D;
  constexpr int blocks = 32;
  // elem_per_thread truncates, so any remainder of D would silently be dropped.
  static_assert(D % BD == 0, "sdpa_vector_2pass_1: D must be a multiple of 32");
  // The combine step gives one lane to each simdgroup's max/sum.
  static_assert(BN <= BD, "sdpa_vector_2pass_1: BN must not exceed 32");

  typedef float U;

  // Finite "no key seen" max (see the header note). 1e9 is exact in float.
  constexpr U empty_max = -1e9f;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions.
  const int block_idx = tid.z;
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * elem_per_thread;
  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
  if (has_mask) {
    mask += head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_seq_stride;
  }
  sums += head_idx * blocks + block_idx;
  maxs += head_idx * blocks + block_idx;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int i = 0; i < elem_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = empty_max;
  U sum_exp_score = 0;

  // For each key owned by this simdgroup in this block.
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    if (!has_mask || mask[0]) {
      // Read this lane's slice of the key.
      for (int j = 0; j < elem_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score: partial dot product per lane, reduced
      // across the simdgroup.
      U score = 0;
      for (int j = 0; j < elem_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      // Online softmax: rescale the old accumulators to the new max.
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator.
      for (int j = 0; j < elem_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to this simdgroup's next key.
    keys += blocks * stride;
    values += blocks * stride;
    if (has_mask) {
      mask += BN * blocks * mask_seq_stride;
    }
  }

  // Each simdgroup holds a partial result; combine them within the block.

  // Share every simdgroup's max and sum_exp. Lanes >= BN contribute the
  // neutral pair (empty_max, 0).
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : empty_max;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the block's sum and max. A single lane is enough; every lane of
  // every simdgroup holds the same values after the reductions.
  if (simd_gid == 0 && simd_lid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Aggregate the outputs. Lane l of simdgroup g stores its rescaled partial
  // at [l][g]. Simdgroup 0 lane l then sums row l, which is element
  // l * elem_per_thread + i over all BN simdgroups.
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      // `out` is a float buffer: store the float accumulator directly
      // instead of rounding it through T (e.g. half) first.
      out[i] = output;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
266
// Second pass of the two-pass single-query attention.
//
// One threadgroup per query head (tid.y) merges the `blocks` partial results
// of that head. The inputs are laid out as in sdpa_vector_2pass_1:
//   partials - blocks x D floats at head_idx * blocks * D;
//   sums     - blocks floats at head_idx * blocks;
//   maxs     - blocks floats at head_idx * blocks.
// Simdgroup g reads block g's partial output, and lane l reads block l's max
// and sum. Each partial is rescaled by exp(maxs[l] - max over blocks), the
// rescaled partials are summed, and the result is divided by the rescaled
// sum of sums. The D results are written to out at head_idx * D.
//
// NOTE(review): if every block reports sum 0 (e.g. all keys masked), the
// division produces NaN/Inf.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;
  // elem_per_thread truncates, so any remainder of D would silently be dropped.
  static_assert(D % BD == 0, "sdpa_vector_2pass_2: D must be a multiple of 32");
  // One simdgroup per block (partials) and one lane per block (sums/maxs);
  // the square transpose below also needs BN == BD.
  static_assert(
      BN == BD && blocks == BN,
      "sdpa_vector_2pass_2: requires BN == BD == blocks");

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions.
  const int head_idx = tid.y;
  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += head_idx * blocks;
  maxs += head_idx * blocks;
  out += head_idx * D + simd_gid * elem_per_thread;

  // Every lane reads one block's max and sum_exp. `factor` rescales block
  // simd_lid to the global max.
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Read this simdgroup's block into registers, then transpose through
  // threadgroup memory. After the transpose, lane l of simdgroup g holds block
  // l's value for element g * elem_per_thread + i, so the simd_sum merges all
  // blocks, each weighted by its own factor.
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write the output; every lane holds the same values after simd_sum.
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:240
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:268
constant bool has_mask
Definition sdpa_vector.h:7
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:130
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, const device bool *mask, const constant int &mask_seq_stride, const constant int &mask_head_stride, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:10