MLX
Loading...
Searching...
No Matches
sdpa_vector.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <metal_simdgroup>
4
5using namespace metal;
6
// Single-pass scaled dot-product attention for one query vector per head.
//
// Work split (as established by the indexing below): tid.y selects the query
// head; simdgroup g of the threadgroup visits keys g, g + BN, g + 2 * BN, ...
// and lane l owns the elem_per_thread contiguous head-dimension elements that
// start at l * elem_per_thread. The kernel assumes it is launched with BN
// simdgroups per threadgroup, since every lane of the final reduction reads
// one simdgroup's statistics.
//
// Each simdgroup keeps a running max score, sum of exponentials and weighted
// value sum (online softmax); the BN partial results are merged through
// threadgroup memory at the end and normalised.
//
// Parameters:
//   queries    - D contiguous elements per query head.
//   keys       - base offset kv_head * k_stride; consecutive keys D apart.
//   values     - base offset kv_head * v_stride; consecutive values D apart.
//   out        - D contiguous output elements per query head.
//   gqa_factor - number of query heads that share one KV head.
//   N          - number of keys / values attended over.
//   scale      - multiplier applied to the query before the dot product.
template <typename T, int D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int stride = BN * D;

  // Each lane must own a whole number of head elements; otherwise the
  // trailing D % BD elements would silently never be read or written.
  static_assert(D % BD == 0, "sdpa_vector: D must be a multiple of 32");
  // The final reduction transposes a BN x BD tile so that simdgroup g ends up
  // with the output elements owned by lane g; that requires a square tile.
  static_assert(BN == BD, "sdpa_vector: reduction assumes BN == BD");

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
  // After the final transpose, simdgroup g holds output elements
  // [g * elem_per_thread, (g + 1) * elem_per_thread).
  out += head_idx * D + simd_gid * elem_per_thread;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int j = 0; j < elem_per_thread; j++) {
    q[j] = static_cast<U>(scale) * queries[j];
  }
  for (int j = 0; j < elem_per_thread; j++) {
    o[j] = 0;
  }

  U max_score = -INFINITY;
  U sum_exp_score = 0;

  // Online softmax over the keys assigned to this simdgroup.
  for (int kv_idx = simd_gid; kv_idx < N; kv_idx += BN) {
    // Read the key slice owned by this lane.
    for (int j = 0; j < elem_per_thread; j++) {
      k[j] = keys[j];
    }

    // Per-lane partial dot product, reduced across the simdgroup so every
    // lane sees the full score.
    U score = 0;
    for (int j = 0; j < elem_per_thread; j++) {
      score += q[j] * k[j];
    }
    score = simd_sum(score);

    // Rescale the running accumulators to the new maximum.
    U new_max = max(max_score, score);
    U factor = fast::exp(max_score - new_max);
    U exp_score = fast::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;

    // Accumulate the exp-weighted value slice.
    for (int j = 0; j < elem_per_thread; j++) {
      o[j] = o[j] * factor + exp_score * values[j];
    }

    // Move the pointers to this simdgroup's next key/value.
    keys += stride;
    values += stride;
  }

  // Each simdgroup holds a partial result; publish its statistics so they can
  // be brought to a common maximum.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Lane l now works with simdgroup l's statistics.
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Transpose each accumulator element through threadgroup memory so lane l of
  // simdgroup g sees simdgroup l's contribution to output element
  // g * elem_per_thread + j, then rescale, sum over simdgroups and normalise.
  for (int j = 0; j < elem_per_thread; j++) {
    outputs[simd_lid * BD + simd_gid] = o[j];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[j] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // After simd_sum every lane holds the same values; one lane writes them.
  if (simd_lid == 0) {
    for (int j = 0; j < elem_per_thread; j++) {
      out[j] = static_cast<T>(o[j]);
    }
  }
}
115
// First pass of the two-pass scaled dot-product attention for one query
// vector per head.
//
// Work split (as established by the indexing below): tid.y selects the query
// head and tid.z one of `blocks` key blocks. Within a threadgroup, simdgroup g
// visits keys block_idx * BN + g, advancing by blocks * BN, and lane l owns
// the elem_per_thread head elements starting at l * elem_per_thread.
//
// Outputs, per (head, block):
//   out  - D floats at (head_idx * blocks + block_idx) * D: the exp-weighted
//          value sum relative to the block max; not divided by the sum.
//   sums - sum of exponentials relative to the block max.
//   maxs - the block max score.
// These use the same layout that sdpa_vector_2pass_2 reads.
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor,
    const constant int& N,
    const constant size_t& k_stride,
    const constant size_t& v_stride,
    const constant float& scale,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int stride = BN * D;
  constexpr int blocks = 32;

  // Each lane must own a whole number of head elements; otherwise the
  // trailing D % BD elements would silently be dropped.
  static_assert(D % BD == 0, "sdpa_vector_2pass_1: D must be a multiple of 32");
  // Lanes [0, BN) read the per-simdgroup statistics in the reduction below.
  static_assert(BN <= BD, "sdpa_vector_2pass_1: BN must not exceed BD");

  typedef float U;

  thread U q[elem_per_thread];
  thread U k[elem_per_thread];
  thread U o[elem_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int head_idx = tid.y;
  const int kv_head_idx = head_idx / gqa_factor;
  queries += head_idx * D + simd_lid * elem_per_thread;
  keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * elem_per_thread;
  values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
      simd_lid * elem_per_thread;
  out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
  sums += head_idx * blocks + block_idx;
  maxs += head_idx * blocks + block_idx;

  // Read the (pre-scaled) query and zero the output accumulator.
  for (int j = 0; j < elem_per_thread; j++) {
    q[j] = static_cast<U>(scale) * queries[j];
  }
  for (int j = 0; j < elem_per_thread; j++) {
    o[j] = 0;
  }

  // Finite sentinel; a simdgroup that sees no keys keeps it, and its
  // contribution is scaled by exp(-1e9 - new_max) in the reduction.
  U max_score = -1e9;
  U sum_exp_score = 0;

  // Online softmax over the keys assigned to this simdgroup within this block.
  for (int kv_idx = block_idx * BN + simd_gid; kv_idx < N;
       kv_idx += blocks * BN) {
    // Read the key slice owned by this lane.
    for (int j = 0; j < elem_per_thread; j++) {
      k[j] = keys[j];
    }

    // Per-lane partial dot product, reduced across the simdgroup.
    U score = 0;
    for (int j = 0; j < elem_per_thread; j++) {
      score += q[j] * k[j];
    }
    score = simd_sum(score);

    // Rescale the running accumulators to the new maximum.
    U new_max = max(max_score, score);
    U factor = fast::exp(max_score - new_max);
    U exp_score = fast::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;

    // Accumulate the exp-weighted value slice.
    for (int j = 0; j < elem_per_thread; j++) {
      o[j] = o[j] * factor + exp_score * values[j];
    }

    // Move the pointers to this simdgroup's next key/value.
    keys += blocks * stride;
    values += blocks * stride;
  }

  // Publish each simdgroup's statistics and combine them to the block max.
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // The reduced sum and max are uniform across the simdgroup, so a single
  // lane is enough to store them.
  if (simd_gid == 0 && simd_lid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Rescale factor that brings this simdgroup's accumulator to the block max;
  // it does not depend on the element index, so compute it once.
  const U group_factor = fast::exp(max_scores[simd_gid] - new_max);

  // Sum the rescaled accumulators of all simdgroups, one element at a time.
  for (int j = 0; j < elem_per_thread; j++) {
    outputs[simd_lid * BN + simd_gid] = o[j] * group_factor;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int g = 1; g < BN; g++) {
        output += outputs[simd_lid * BN + g];
      }
      // `out` is a float buffer: store the float accumulator directly rather
      // than rounding it through T (e.g. half) first.
      out[j] = output;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
241
// Second pass of the two-pass scaled dot-product attention: merges the
// `blocks` per-block partial results of one query head into the final,
// normalised output.
//
// Work split (as established by the indexing below): tid.y selects the query
// head. Simdgroup g loads block g's partial output (lane l owns elements
// starting at l * elem_per_thread), and lane l reads block l's max and sum of
// exponentials. Partials are rescaled to the common maximum, summed over
// blocks and divided by the combined sum.
//
// Input layout: partials[(head * blocks + block) * D + e],
//               sums/maxs[head * blocks + block]
// (the layout sdpa_vector_2pass_1 writes, which also uses blocks = 32).
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  // Each lane must own a whole number of head elements; otherwise the
  // trailing D % BD elements would silently be dropped.
  static_assert(D % BD == 0, "sdpa_vector_2pass_2: D must be a multiple of 32");
  // The transpose below uses a square BN x BD tile.
  static_assert(BN == BD, "sdpa_vector_2pass_2: reduction assumes BN == BD");
  // One simdgroup per block (partials) and one lane per block (sums/maxs).
  static_assert(
      blocks == BN && blocks == BD,
      "sdpa_vector_2pass_2: blocks must equal both BN and BD");

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.y;
  partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += head_idx * blocks;
  maxs += head_idx * blocks;
  // After the transpose, simdgroup g holds output elements
  // [g * elem_per_thread, (g + 1) * elem_per_thread).
  out += head_idx * D + simd_gid * elem_per_thread;

  // Lane l reads block l's statistics; combine them to the global maximum.
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Load this simdgroup's block partial into registers, then transpose it
  // through threadgroup memory so lane l of simdgroup g sees block l's value
  // for output element g * elem_per_thread + i.
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // After simd_sum every lane holds the same values; one lane writes them.
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
METAL_FUNC bfloat16_t exp(bfloat16_t x)
Definition bf16_math.h:240
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
void sdpa_vector_2pass_2(const device float *partials, const device float *sums, const device float *maxs, device T *out, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:243
void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:8
void sdpa_vector_2pass_1(const device T *queries, const device T *keys, const device T *values, device float *out, device float *sums, device float *maxs, const constant int &gqa_factor, const constant int &N, const constant size_t &k_stride, const constant size_t &v_stride, const constant float &scale, uint3 tid, uint simd_gid, uint simd_lid)
Definition sdpa_vector.h:117