MLX
 
quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10MLX_MTL_CONST int SIMD_SIZE = 32;
11MLX_MTL_CONST int QUAD_SIZE = 4;
12
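// load_vector reads values_per_thread inputs of x into registers and returns their sum.
// For sub-byte widths each lane is pre-divided by the power of two at its bit offset,
// so qdot can multiply against masked (unshifted) weight bits.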
13template <typename T, typename U, int values_per_thread, int bits>
14inline U load_vector(const device T* x, thread U* x_thread) {
15 static_assert(
16 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
17 "Template undefined for bits not in {2, 3, 4, 6, 8}");
18
19 U sum = 0;
20
21 if (bits == 2) {
22 for (int i = 0; i < values_per_thread; i += 4) {
23 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
24 x_thread[i] = x[i];
25 x_thread[i + 1] = x[i + 1] / 4.0f;
26 x_thread[i + 2] = x[i + 2] / 16.0f;
27 x_thread[i + 3] = x[i + 3] / 64.0f;
28 }
29 }
30
31 else if (bits == 3) {
32 for (int i = 0; i < values_per_thread; i += 8) {
33 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
34 x[i + 6] + x[i + 7];
35 x_thread[i] = x[i];
36 x_thread[i + 1] = x[i + 1] / 8.0f;
37 x_thread[i + 2] = x[i + 2] / 64.0f;
38 x_thread[i + 3] = x[i + 3] / 2.0f;
39 x_thread[i + 4] = x[i + 4] / 16.0f;
40 x_thread[i + 5] = x[i + 5] / 128.0f;
41 x_thread[i + 6] = x[i + 6] / 4.0f;
42 x_thread[i + 7] = x[i + 7] / 32.0f;
43 }
44 }
45
46 else if (bits == 4) {
47 for (int i = 0; i < values_per_thread; i += 4) {
48 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
49 x_thread[i] = x[i];
50 x_thread[i + 1] = x[i + 1] / 16.0f;
51 x_thread[i + 2] = x[i + 2] / 256.0f;
52 x_thread[i + 3] = x[i + 3] / 4096.0f;
53 }
54 }
55
56 else if (bits == 6) {
57 for (int i = 0; i < values_per_thread; i += 4) {
58 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
59 x_thread[i] = x[i];
60 x_thread[i + 1] = x[i + 1] / 64.0f;
61 x_thread[i + 2] = x[i + 2] / 16.0f;
62 x_thread[i + 3] = x[i + 3] / 4.0f;
63 }
64 }
65
66 else if (bits == 8) {
67 for (int i = 0; i < values_per_thread; i++) {
68 sum += x[i];
69 x_thread[i] = x[i];
70 }
71 }
72
73 return sum;
74}
75
76template <typename T, typename U, int values_per_thread, int bits>
77inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
78 static_assert(
79 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
80 "Template undefined for bits not in {2, 3, 4, 6, 8}");
81
82 U sum = 0;
83
84 if (bits == 2) {
85 for (int i = 0; i < N; i += 4) {
86 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
87 x_thread[i] = x[i];
88 x_thread[i + 1] = x[i + 1] / 4.0f;
89 x_thread[i + 2] = x[i + 2] / 16.0f;
90 x_thread[i + 3] = x[i + 3] / 64.0f;
91 }
92 }
93
94 else if (bits == 3) {
95 for (int i = 0; i < N; i += 8) {
96 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
97 x[i + 6] + x[i + 7];
98
99 x_thread[i] = x[i];
100 x_thread[i + 1] = x[i + 1] / 8.0f;
101 x_thread[i + 2] = x[i + 2] / 64.0f;
102 x_thread[i + 3] = x[i + 3] / 2.0f;
103 x_thread[i + 4] = x[i + 4] / 16.0f;
104 x_thread[i + 5] = x[i + 5] / 128.0f;
105 x_thread[i + 6] = x[i + 6] / 4.0f;
106 x_thread[i + 7] = x[i + 7] / 32.0f;
107 }
108 }
109
110 else if (bits == 4) {
111 for (int i = 0; i < N; i += 4) {
112 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
113 x_thread[i] = x[i];
114 x_thread[i + 1] = x[i + 1] / 16.0f;
115 x_thread[i + 2] = x[i + 2] / 256.0f;
116 x_thread[i + 3] = x[i + 3] / 4096.0f;
117 }
118 }
119
120 else if (bits == 6) {
121 for (int i = 0; i < N; i += 4) {
122 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
123 x_thread[i] = x[i];
124 x_thread[i + 1] = x[i + 1] / 64.0f;
125 x_thread[i + 2] = x[i + 2] / 16.0f;
126 x_thread[i + 3] = x[i + 3] / 4.0f;
127 }
128 }
129
130 else if (bits == 8) {
131 for (int i = 0; i < N; i++) {
132 sum += x[i];
133 x_thread[i] = x[i];
134 }
135 }
136
137 for (int i = N; i < values_per_thread; i++) {
138 x_thread[i] = 0;
139 }
140
141 return sum;
142}
143
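// qdot takes the pre-scaled inputs from load_vector and the packed weight bytes and
// returns scale * (x . q) + bias * sum(x), i.e. the affine dequantization folded out
// of the inner loop. qdot_safe below is the bounds-checked variant for partial blocks.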
144template <typename U, int values_per_thread, int bits>
145inline U qdot(
146 const device uint8_t* w,
147 const thread U* x_thread,
148 U scale,
149 U bias,
150 U sum) {
151 static_assert(
152 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
153 "Template undefined for bits not in {2, 3, 4, 6, 8}");
154
155 U accum = 0;
156
157 if (bits == 2) {
158 for (int i = 0; i < (values_per_thread / 4); i++) {
159 accum +=
160 (x_thread[4 * i] * (w[i] & 0x03) +
161 x_thread[4 * i + 1] * (w[i] & 0x0c) +
162 x_thread[4 * i + 2] * (w[i] & 0x30) +
163 x_thread[4 * i + 3] * (w[i] & 0xc0));
164 }
165 }
166
167 else if (bits == 3) {
168 for (int i = 0; i < (values_per_thread / 8); i++) {
169 x_thread += 8 * i;
170 w += 3 * i;
171
172 accum += (w[0] & 0x07) * x_thread[0];
173 accum += (w[0] & 0x38) * x_thread[1];
174 accum += (w[0] & 0xc0) * x_thread[2];
175 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
176
177 accum += (w[1] & 0x0e) * x_thread[3];
178 accum += (w[1] & 0x70) * x_thread[4];
179 accum += (w[1] & 0x80) * x_thread[5];
180 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
181
182 accum += (w[2] & 0x1c) * x_thread[6];
183 accum += (w[2] & 0xe0) * x_thread[7];
184 }
185 }
186
187 else if (bits == 4) {
188 const device uint16_t* ws = (const device uint16_t*)w;
189 for (int i = 0; i < (values_per_thread / 4); i++) {
190 accum +=
191 (x_thread[4 * i] * (ws[i] & 0x000f) +
192 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
193 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
194 x_thread[4 * i + 3] * (ws[i] & 0xf000));
195 }
196 }
197
198 else if (bits == 6) {
199 for (int i = 0; i < (values_per_thread / 4); i++) {
200 x_thread += 4 * i;
201 w += 3 * i;
202
203 accum += (w[0] & 0x3f) * x_thread[0];
204
205 accum += (w[0] & 0xc0) * x_thread[1];
206 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
207
208 accum += (w[1] & 0xf0) * x_thread[2];
209 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
210
211 accum += (w[2] & 0xfc) * x_thread[3];
212 }
213 }
214
215 else if (bits == 8) {
216 for (int i = 0; i < values_per_thread; i++) {
217 accum += x_thread[i] * w[i];
218 }
219 }
220
221 return scale * accum + sum * bias;
222}
223
224template <typename U, int values_per_thread, int bits>
225inline U qdot_safe(
226 const device uint8_t* w,
227 const thread U* x_thread,
228 U scale,
229 U bias,
230 U sum,
231 int N) {
232 static_assert(
233 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
234 "Template undefined for bits not in {2, 3, 4, 6, 8}");
235
236 U accum = 0;
237
238 if (bits == 2) {
239 for (int i = 0; i < (N / 4); i++) {
240 accum +=
241 (x_thread[4 * i] * (w[i] & 0x03) +
242 x_thread[4 * i + 1] * (w[i] & 0x0c) +
243 x_thread[4 * i + 2] * (w[i] & 0x30) +
244 x_thread[4 * i + 3] * (w[i] & 0xc0));
245 }
246 }
247
248 else if (bits == 3) {
249 for (int i = 0; i < (N / 8); i++) {
250 x_thread += 8 * i;
251 w += 3 * i;
252
253 accum += (w[0] & 0x07) * x_thread[0];
254 accum += (w[0] & 0x38) * x_thread[1];
255 accum += (w[0] & 0xc0) * x_thread[2];
256 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
257
258 accum += (w[1] & 0x0e) * x_thread[3];
259 accum += (w[1] & 0x70) * x_thread[4];
260 accum += (w[1] & 0x80) * x_thread[5];
261 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
262
263 accum += (w[2] & 0x1c) * x_thread[6];
264 accum += (w[2] & 0xe0) * x_thread[7];
265 }
266 }
267
268 else if (bits == 4) {
269 const device uint16_t* ws = (const device uint16_t*)w;
270 for (int i = 0; i < (N / 4); i++) {
271 accum +=
272 (x_thread[4 * i] * (ws[i] & 0x000f) +
273 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
274 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
275 x_thread[4 * i + 3] * (ws[i] & 0xf000));
276 }
277 }
278
279 else if (bits == 6) {
280 for (int i = 0; i < (N / 4); i++) {
281 x_thread += 4 * i;
282 w += 3 * i;
283
284 accum += (w[0] & 0x3f) * x_thread[0];
285
286 accum += (w[0] & 0xc0) * x_thread[1];
287 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
288
289 accum += (w[1] & 0xf0) * x_thread[2];
290 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
291
292 accum += (w[2] & 0xfc) * x_thread[3];
293 }
294 }
295
296 else if (bits == 8) {
297 for (int i = 0; i < N; i++) {
298 accum += x_thread[i] * w[i];
299 }
300 }
301
302 return scale * accum + sum * bias;
303}
304
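// qouter accumulates x * (scale * w + bias) into result for each of the
// values_per_thread packed weights (the rank-1 update used by qvm_impl).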
305template <typename U, int values_per_thread, int bits>
306inline void
307qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
308 static_assert(
309 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
310 "Template undefined for bits not in {2, 3, 4, 6, 8}");
311
312 if (bits == 2) {
313 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
314 for (int i = 0; i < (values_per_thread / 4); i++) {
315 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
316 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
317 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
318 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
319 }
320 }
321
322 else if (bits == 3) {
323 for (int i = 0; i < (values_per_thread / 8); i++) {
324 uint8_t w0 = w[3 * i];
325 uint8_t w1 = w[3 * i + 1];
326 uint8_t w2 = w[3 * i + 2];
327
328 result[8 * i] += x * ((w0 & 0x7) * scale + bias);
329 result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
330 result[8 * i + 2] +=
331 x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
332 result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
333 result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
334 result[8 * i + 5] +=
335 x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
336 result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
337 result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
338 }
339 }
340
341 else if (bits == 4) {
342 U s[2] = {scale, scale / 16.0f};
343 for (int i = 0; i < (values_per_thread / 2); i++) {
344 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
345 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
346 }
347
348 } else if (bits == 6) {
349 for (int i = 0; i < (values_per_thread / 4); i++) {
350 uint8_t w0 = w[3 * i];
351 uint8_t w1 = w[3 * i + 1];
352 uint8_t w2 = w[3 * i + 2];
353
354 result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
355 result[4 * i + 1] +=
356 x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
357 result[4 * i + 2] +=
358 x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
359 result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
360 }
361 }
362
363 else if (bits == 8) {
364 for (int i = 0; i < values_per_thread; i++) {
365 result[i] += x * (scale * w[i] + bias);
366 }
367 }
368}
369
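// dequantize unpacks N quantized values from w into threadgroup memory,
// applying the group's scale and bias.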
370template <typename U, int N, int bits>
371inline void
372dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
373 static_assert(
374 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
375 "Template undefined for bits not in {2, 3, 4, 6, 8}");
376
377 if (bits == 2) {
378 U s[4] = {
379 scale,
380 scale / static_cast<U>(4.0f),
381 scale / static_cast<U>(16.0f),
382 scale / static_cast<U>(64.0f)};
383 for (int i = 0; i < (N / 4); i++) {
384 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
385 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
386 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
387 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
388 }
389 }
390
391 else if (bits == 3) {
392 for (int i = 0; i < (N / 8); i++) {
393 w_local += 8 * i;
394 w += 3 * i;
395
396 w_local[0] = (w[0] & 0x7) * scale + bias;
397 w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
398 w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
399 w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
400 w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
401 w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
402 w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
403 w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
404 }
405 }
406
407 else if (bits == 4) {
408 U s[2] = {scale, scale / static_cast<U>(16.0f)};
409 for (int i = 0; i < (N / 2); i++) {
410 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
411 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
412 }
413 }
414
415 else if (bits == 6) {
416 for (int i = 0; i < (N / 4); i++) {
417 w_local += 4 * i;
418 w += 3 * i;
419
420 w_local[0] = (w[0] & 0x3f) * scale + bias;
421 w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
422 w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
423 w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
424 }
425 }
426
427 else if (bits == 8) {
428 for (int i = 0; i < N; i++) {
429 w_local[i] = scale * w[i] + bias;
430 }
431 }
432}
433
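// QuantizedBlockLoader dequantizes a BROWS x BCOLS tile of packed weights into
// threadgroup memory; next() advances along the reduction dimension and steps the
// scales/biases pointers once per quantization group.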
434template <
435 typename T,
436 short BROWS,
437 short BCOLS,
438 short dst_ld,
439 short reduction_dim,
440 short tgp_size,
441 short group_size,
442 short bits>
443struct QuantizedBlockLoader {
444 static_assert(
445 BCOLS <= group_size,
446 "The group size should be larger than the columns");
447 static_assert(
448 group_size % BCOLS == 0,
449 "The group size should be divisible by the columns");
450 static_assert(
451 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
452 "Template undefined for bits not in {2, 3, 4, 6, 8}");
453
454 MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
455 MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
456 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
457 MLX_MTL_CONST short n_reads =
458 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
459 MLX_MTL_CONST short group_steps = group_size / BCOLS;
460
461 const int src_ld;
462 const int tile_stride;
463 short group_step_cnt;
464 const int group_stride;
465
466 const short thread_idx;
467 const short bi;
468 const short bj;
469
470 threadgroup T* dst;
471 const device uint8_t* src;
472 const device T* scales;
473 const device T* biases;
474
475 QuantizedBlockLoader(
476 const device uint8_t* src_,
477 const device T* scales_,
478 const device T* biases_,
479 const int src_ld_,
480 threadgroup T* dst_,
481 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
482 ushort simd_lane_id [[thread_index_in_simdgroup]])
483 : src_ld(src_ld_),
484 tile_stride(
485 reduction_dim ? BCOLS_PACKED * bytes_per_pack
486 : BROWS * src_ld * bytes_per_pack / pack_factor),
487 group_step_cnt(0),
488 group_stride(BROWS * src_ld / group_size),
489 thread_idx(simd_group_id * 32 + simd_lane_id),
490 bi(n_reads * thread_idx / BCOLS_PACKED),
491 bj((n_reads * thread_idx) % BCOLS_PACKED),
492 dst(dst_ + bi * dst_ld + bj * pack_factor),
493 src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
494 bj * bytes_per_pack),
495 scales(scales_ + bi * src_ld / group_size),
496 biases(biases_ + bi * src_ld / group_size) {}
497
498 void load_unsafe() const {
499 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
500 return;
501 }
502
503 T scale = *scales;
504 T bias = *biases;
505 for (int i = 0; i < n_reads; i++) {
506 dequantize<T, pack_factor, bits>(
507 src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
508 }
509 }
510
511 void load_safe(short2 src_tile_dim) const {
512 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
513 return;
514 }
515
516 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
517 for (int i = 0; i < n_reads * pack_factor; i++) {
518 dst[i] = T(0);
519 }
520 return;
521 }
522
523 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
524 for (int i = 0; i < n_reads * pack_factor; i++) {
525 dst[i] = T(0);
526 }
527 return;
528 }
529
530 T scale = *scales;
531 T bias = *biases;
532 for (int i = 0; i < n_reads; i++) {
533 dequantize<T, pack_factor, bits>(
534 (device uint8_t*)(src + i * bytes_per_pack),
535 scale,
536 bias,
537 dst + i * pack_factor);
538 }
539 }
540
541 void next() {
542 src += tile_stride;
543 if (reduction_dim == 1) {
544 if (group_steps > 1) {
545 group_step_cnt++;
546 if (group_step_cnt == group_steps) {
547 group_step_cnt = 0;
548 scales++;
549 biases++;
550 }
551 } else {
552 scales++;
553 biases++;
554 }
555 } else {
556 scales += group_stride;
557 biases += group_stride;
558 }
559 }
560};
561
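// Matrix-vector product for short rows: each quad of QUAD_SIZE threads reduces one
// row of length D and every quadgroup covers results_per_quadgroup output rows.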
562template <typename T, int group_size, int bits, int D>
563METAL_FUNC void qmv_quad_impl(
564 const device uint32_t* w,
565 const device T* scales,
566 const device T* biases,
567 const device T* x,
568 device T* y,
569 constant int& in_vec_size,
570 const constant int& out_vec_size,
571 uint3 tid [[threadgroup_position_in_grid]],
572 uint quad_gid [[quadgroup_index_in_threadgroup]],
573 uint quad_lid [[thread_index_in_quadgroup]]) {
574 constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
575 constexpr int pack_factor = 32 / bits;
576 constexpr int values_per_thread = D / QUAD_SIZE;
577 constexpr int packs_per_thread = values_per_thread / pack_factor;
578 constexpr int scale_step_per_thread = group_size / values_per_thread;
579 constexpr int results_per_quadgroup = 8;
580
581 typedef float U;
582
583 thread U x_thread[values_per_thread];
584 thread U result[results_per_quadgroup] = {0};
585
586 // Adjust positions
587 const int in_vec_size_w = in_vec_size / pack_factor;
588 const int in_vec_size_g = in_vec_size / group_size;
589 const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
590
591 w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
592 scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
593 biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
594 x += tid.y * in_vec_size + quad_lid * values_per_thread;
595 y += tid.y * out_vec_size + out_row;
596
598
599 for (int row = 0; row < results_per_quadgroup; row++) {
600 auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
601 const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
602 const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
603
604 U s = sl[0];
605 U b = bl[0];
606 if (row * quads_per_simd + out_row < out_vec_size) {
607 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
608 }
609 }
610
611 for (int row = 0; row < results_per_quadgroup; row++) {
612 result[row] = quad_sum(result[row]);
613 if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
614 y[row * quads_per_simd] = static_cast<T>(result[row]);
615 }
616 }
617}
618
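// Matrix-vector product without edge handling: two simdgroups per threadgroup, four
// output rows per simdgroup; assumes the vector sizes are multiples of the block.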
619template <typename T, int group_size, int bits>
620METAL_FUNC void qmv_fast_impl(
621 const device uint32_t* w,
622 const device T* scales,
623 const device T* biases,
624 const device T* x,
625 device T* y,
626 const constant int& in_vec_size,
627 const constant int& out_vec_size,
628 uint3 tid [[threadgroup_position_in_grid]],
629 uint simd_gid [[simdgroup_index_in_threadgroup]],
630 uint simd_lid [[thread_index_in_simdgroup]]) {
631 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
632 constexpr int packs_per_thread = bits == 2 ? 1 : 2;
633 constexpr int num_simdgroups = 2;
634 constexpr int results_per_simdgroup = 4;
635 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
636 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
637 constexpr int values_per_thread = pack_factor * packs_per_thread;
638 constexpr int block_size = values_per_thread * SIMD_SIZE;
639 constexpr int scale_step_per_thread = group_size / values_per_thread;
640
641 const device uint8_t* ws = (const device uint8_t*)w;
642
643 typedef float U;
644
645 thread U x_thread[values_per_thread];
646 thread U result[results_per_simdgroup] = {0};
647
648 // Adjust positions
649 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
650 const int in_vec_size_g = in_vec_size / group_size;
651 const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
652 simd_gid * results_per_simdgroup;
653
654 ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
655 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
656 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
657 x += tid.x * in_vec_size + simd_lid * values_per_thread;
658 y += tid.x * out_vec_size + out_row;
659
660 for (int k = 0; k < in_vec_size; k += block_size) {
661 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
662
663 for (int row = 0; row < results_per_simdgroup; row++) {
664 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
665 const device T* sl = scales + row * in_vec_size_g;
666 const device T* bl = biases + row * in_vec_size_g;
667
668 U s = sl[0];
669 U b = bl[0];
670 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
671 }
672
673 ws += block_size * bytes_per_pack / pack_factor;
674 scales += block_size / group_size;
675 biases += block_size / group_size;
676 x += block_size;
677 }
678
679 for (int row = 0; row < results_per_simdgroup; row++) {
680 result[row] = simd_sum(result[row]);
681 if (simd_lid == 0) {
682 y[row] = static_cast<T>(result[row]);
683 }
684 }
685}
686
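// General matrix-vector product with guards for ragged shapes; when a full tile does
// not fit, the last tile is shifted back (used_out_row) so some outputs are recomputed.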
687template <typename T, int group_size, int bits>
688METAL_FUNC void qmv_impl(
689 const device uint32_t* w,
690 const device T* scales,
691 const device T* biases,
692 const device T* x,
693 device T* y,
694 const constant int& in_vec_size,
695 const constant int& out_vec_size,
696 uint3 tid [[threadgroup_position_in_grid]],
697 uint simd_gid [[simdgroup_index_in_threadgroup]],
698 uint simd_lid [[thread_index_in_simdgroup]]) {
699 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
700 constexpr int num_simdgroups = 2;
701 constexpr int results_per_simdgroup = 4;
702 constexpr int packs_per_thread = 1;
703 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
704 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
705 constexpr int values_per_thread = pack_factor * packs_per_thread;
706 constexpr int block_size = values_per_thread * SIMD_SIZE;
707 constexpr int scale_step_per_thread = group_size / values_per_thread;
708
709 const device uint8_t* ws = (const device uint8_t*)w;
710
711 typedef float U;
712
713 thread U x_thread[values_per_thread];
714 thread U result[results_per_simdgroup] = {0};
715
716 // Adjust positions
717 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
718 const int in_vec_size_g = in_vec_size / group_size;
719 const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
720 simd_gid * results_per_simdgroup;
721 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
722
723 if (out_row >= out_vec_size) {
724 return;
725 }
726
727 // In this case we need to properly guard all our reads because there isn't
728 // even 1 tile in the matrix
729 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
730 ws +=
731 out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
732 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
733 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
734 x += tid.x * in_vec_size + simd_lid * values_per_thread;
735 y += tid.x * out_vec_size + out_row;
736
737 int k = 0;
738 for (; k < in_vec_size - block_size; k += block_size) {
739 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
740
741 for (int row = 0; out_row + row < out_vec_size; row++) {
742 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
743 const device T* sl = scales + row * in_vec_size_g;
744 const device T* bl = biases + row * in_vec_size_g;
745
746 U s = sl[0];
747 U b = bl[0];
748 result[row] +=
749 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
750 }
751
752 ws += block_size * bytes_per_pack / pack_factor;
753 scales += block_size / group_size;
754 biases += block_size / group_size;
755 x += block_size;
756 }
757 const int remaining = clamp(
758 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
759 0,
760 values_per_thread);
761 if (remaining > 0) {
762 U sum = load_vector_safe<T, U, values_per_thread, bits>(
763 x, x_thread, remaining);
764
765 for (int row = 0; out_row + row < out_vec_size; row++) {
766 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
767 const device T* sl = scales + row * in_vec_size_g;
768 const device T* bl = biases + row * in_vec_size_g;
769
770 U s = sl[0];
771 U b = bl[0];
772 result[row] +=
773 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
774 }
775 }
776
777 for (int row = 0; out_row + row < out_vec_size; row++) {
778 result[row] = simd_sum(result[row]);
779 if (simd_lid == 0) {
780 y[row] = static_cast<T>(result[row]);
781 }
782 }
783 }
784
785 // In this case the last tile is moved back to redo some output values
786 else {
787 ws += used_out_row * in_vec_size_w +
788 simd_lid * packs_per_thread * bytes_per_pack;
789 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
790 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
791 x += tid.x * in_vec_size + simd_lid * values_per_thread;
792 y += tid.x * out_vec_size + used_out_row;
793
794 int k = 0;
795 for (; k < in_vec_size - block_size; k += block_size) {
796 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
797
798 for (int row = 0; row < results_per_simdgroup; row++) {
799 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
800 const device T* sl = scales + row * in_vec_size_g;
801 const device T* bl = biases + row * in_vec_size_g;
802
803 U s = sl[0];
804 U b = bl[0];
805 result[row] +=
806 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
807 }
808
809 ws += block_size * bytes_per_pack / pack_factor;
810 scales += block_size / group_size;
811 biases += block_size / group_size;
812 x += block_size;
813 }
814 const int remaining = clamp(
815 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
816 0,
817 values_per_thread);
818 if (remaining > 0) {
819 U sum = load_vector_safe<T, U, values_per_thread, bits>(
820 x, x_thread, remaining);
821
822 for (int row = 0; row < results_per_simdgroup; row++) {
823 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
824 const device T* sl = scales + row * in_vec_size_g;
825 const device T* bl = biases + row * in_vec_size_g;
826
827 U s = sl[0];
828 U b = bl[0];
829 result[row] += qdot_safe<U, values_per_thread, bits>(
830 wl, x_thread, s, b, sum, remaining);
831 }
832 }
833 for (int row = 0; row < results_per_simdgroup; row++) {
834 result[row] = simd_sum(result[row]);
835 if (simd_lid == 0) {
836 y[row] = static_cast<T>(result[row]);
837 }
838 }
839 }
840}
841
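// Vector-matrix product: each thread dequantizes a strip of one weight row with qouter
// and the partial results are reduced across the simdgroup with simd_sum.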
842template <typename T, const int group_size, const int bits>
843METAL_FUNC void qvm_impl(
844 const device uint32_t* w,
845 const device T* scales,
846 const device T* biases,
847 const device T* x,
848 device T* y,
849 const int in_vec_size,
850 const int out_vec_size,
851 uint3 tid [[threadgroup_position_in_grid]],
852 uint simd_gid [[simdgroup_index_in_threadgroup]],
853 uint simd_lid [[thread_index_in_simdgroup]]) {
854 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
855 constexpr int num_simdgroups = 2;
856 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
857 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
858 constexpr int tn = 32 / pack_factor;
859 constexpr int block_size = SIMD_SIZE;
860
861 using W_T =
862 typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
863 const device W_T* ws = (const device W_T*)w;
864
865 typedef float U;
866 typedef struct {
867 W_T wi[tn * bytes_per_pack];
868 } vec_w;
869
870 thread vec_w w_local;
871 thread U result[tn * pack_factor] = {0};
872 thread U scale = 1;
873 thread U bias = 0;
874 thread U x_local = 0;
875
876 // Adjust positions
877 const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
878 const int out_vec_size_g = out_vec_size / group_size;
879 int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
880 ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
881 scales += out_col / group_size + simd_lid * out_vec_size_g;
882 biases += out_col / group_size + simd_lid * out_vec_size_g;
883 x += tid.x * in_vec_size + simd_lid;
884 y += tid.x * out_vec_size + out_col;
885
886 if (out_col >= out_vec_size) {
887 return;
888 }
889
890 // Loop over in_vec in blocks of block_size
891 int remaining = in_vec_size % block_size;
892 if (remaining == 0) {
893 for (int i = 0; i < in_vec_size; i += block_size) {
894 x_local = *x;
895 scale = *scales;
896 bias = *biases;
897 w_local = *((device vec_w*)ws);
898 qouter<U, tn * pack_factor, bits>(
899 (thread uint8_t*)&w_local, x_local, scale, bias, result);
900
901 x += block_size;
902 scales += block_size * out_vec_size_g;
903 biases += block_size * out_vec_size_g;
904 ws += block_size * out_vec_size_w;
905 }
906 } else {
907 for (int i = block_size; i < in_vec_size; i += block_size) {
908 x_local = *x;
909 scale = *scales;
910 bias = *biases;
911 w_local = *((device vec_w*)ws);
912
913 qouter<U, tn * pack_factor, bits>(
914 (thread uint8_t*)&w_local, x_local, scale, bias, result);
915
916 x += block_size;
917 scales += block_size * out_vec_size_g;
918 biases += block_size * out_vec_size_g;
919 ws += block_size * out_vec_size_w;
920 }
921 if (static_cast<int>(simd_lid) < remaining) {
922 x_local = *x;
923 scale = *scales;
924 bias = *biases;
925 w_local = *((device vec_w*)ws);
926 } else {
927 x_local = 0;
928 scale = 0;
929 bias = 0;
930 }
931 qouter<U, tn * pack_factor, bits>(
932 (thread uint8_t*)&w_local, x_local, scale, bias, result);
933 }
934
935// Accumulate in the simdgroup
936#pragma clang loop unroll(full)
937 for (int k = 0; k < tn * pack_factor; k++) {
938 result[k] = simd_sum(result[k]);
939 }
940
941 // Store the result
942 if (simd_lid == 0) {
943#pragma clang loop unroll(full)
944 for (int k = 0; k < tn * pack_factor; k++) {
945 y[k] = static_cast<T>(result[k]);
946 }
947 }
948}
949
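// Tiled matmul y = x * w^T built on steel BlockMMA; x tiles are staged with BlockLoader
// and w tiles are dequantized into threadgroup memory by QuantizedBlockLoader.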
950template <
951 typename T,
952 const int group_size,
953 const int bits,
954 const bool aligned_N,
955 const int BM = 32,
956 const int BK = 32,
957 const int BN = 32>
958METAL_FUNC void qmm_t_impl(
959 const device uint32_t* w,
960 const device T* scales,
961 const device T* biases,
962 const device T* x,
963 device T* y,
964 threadgroup T* Xs,
965 threadgroup T* Ws,
966 const constant int& K,
967 const constant int& N,
968 const constant int& M,
969 uint3 tid [[threadgroup_position_in_grid]],
970 uint lid [[thread_index_in_threadgroup]],
971 uint simd_gid [[simdgroup_index_in_threadgroup]],
972 uint simd_lid [[thread_index_in_simdgroup]]) {
973 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
974 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
975
976 (void)lid;
977
978 constexpr int WM = 2;
979 constexpr int WN = 2;
980 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
981 constexpr int BK_padded = (BK + 16 / sizeof(T));
982 constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
983
984 // Instantiate the appropriate BlockMMA and Loader
985 using mma_t = mlx::steel::
986 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
987 using loader_x_t =
988 mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
989 using loader_w_t = QuantizedBlockLoader<
990 T,
991 BN,
992 BK,
993 BK_padded,
994 1,
995 WM * WN * SIMD_SIZE,
996 group_size,
997 bits>;
998
999 // Set the block
1000 const int K_w = K * bytes_per_pack / pack_factor;
1001 const int K_g = K / group_size;
1002 const int y_row = tid.y * BM;
1003 const int y_col = tid.x * BN;
1004
1005 auto wl = (const device uint8_t*)w;
1006
1007 x += y_row * K;
1008 wl += y_col * K_w;
1009 scales += y_col * K_g;
1010 biases += y_col * K_g;
1011 y += y_row * N + y_col;
1012
1013 // Make the x loader and mma operation
1014 const short num_els = min(BM, M - y_row);
1015 const short num_outs = min(BN, N - y_col);
1016 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1017 loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
1018 mma_t mma_op(simd_gid, simd_lid);
1019
1020 if (num_els < BM) {
1021 if (!aligned_N && num_outs < BN) {
1022 for (int k = 0; k < K; k += BK) {
1023 threadgroup_barrier(mem_flags::mem_threadgroup);
1024 loader_x.load_safe(short2(BK, num_els));
1025 loader_w.load_safe(short2(BK, num_outs));
1026 threadgroup_barrier(mem_flags::mem_threadgroup);
1027 mma_op.mma(Xs, Ws);
1028 loader_x.next();
1029 loader_w.next();
1030 }
1031 } else {
1032 for (int k = 0; k < K; k += BK) {
1033 threadgroup_barrier(mem_flags::mem_threadgroup);
1034 loader_x.load_safe(short2(BK, num_els));
1035 loader_w.load_unsafe();
1036 threadgroup_barrier(mem_flags::mem_threadgroup);
1037 mma_op.mma(Xs, Ws);
1038 loader_x.next();
1039 loader_w.next();
1040 }
1041 }
1042 } else {
1043 if (!aligned_N && num_outs < BN) {
1044 for (int k = 0; k < K; k += BK) {
1045 threadgroup_barrier(mem_flags::mem_threadgroup);
1046 loader_x.load_unsafe();
1047 loader_w.load_safe(short2(BK, num_outs));
1048 threadgroup_barrier(mem_flags::mem_threadgroup);
1049 mma_op.mma(Xs, Ws);
1050 loader_x.next();
1051 loader_w.next();
1052 }
1053 } else {
1054 for (int k = 0; k < K; k += BK) {
1055 threadgroup_barrier(mem_flags::mem_threadgroup);
1056 loader_x.load_unsafe();
1057 loader_w.load_unsafe();
1058 threadgroup_barrier(mem_flags::mem_threadgroup);
1059
1060 mma_op.mma(Xs, Ws);
1061 loader_x.next();
1062 loader_w.next();
1063 }
1064 }
1065 }
1066
1067 // Store results to device memory
1068 threadgroup_barrier(mem_flags::mem_threadgroup);
1069 if (num_els < BM || num_outs < BN) {
1070 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
1071 } else {
1072 mma_op.store_result(y, N);
1073 }
1074}
1075
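// Tiled matmul y = x * w (weights not transposed); ragged K and M edges take the
// load_safe paths.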
1076template <
1077 typename T,
1078 const int group_size,
1079 const int bits,
1080 const int BM = 32,
1081 const int BK = 32,
1082 const int BN = 32>
1083METAL_FUNC void qmm_n_impl(
1084 const device uint32_t* w,
1085 const device T* scales,
1086 const device T* biases,
1087 const device T* x,
1088 device T* y,
1089 threadgroup T* Xs,
1090 threadgroup T* Ws,
1091 const constant int& K,
1092 const constant int& N,
1093 const constant int& M,
1094 uint3 tid [[threadgroup_position_in_grid]],
1095 uint lid [[thread_index_in_threadgroup]],
1096 uint simd_gid [[simdgroup_index_in_threadgroup]],
1097 uint simd_lid [[thread_index_in_simdgroup]]) {
1098 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
1099 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
1100
1101 (void)lid;
1102
1103 constexpr int WM = 2;
1104 constexpr int WN = 2;
1105 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
1106 constexpr int BK_padded = (BK + 16 / sizeof(T));
1107 constexpr int BN_padded = (BN + 16 / sizeof(T));
1108 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
1109 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
1110
1111 // Instantiate the appropriate BlockMMA and Loader
1112 using mma_t = mlx::steel::
1113 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
1114 using loader_x_t = mlx::steel::
1115 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
1116 using loader_w_t = QuantizedBlockLoader<
1117 T,
1118 BK,
1119 BN,
1120 BN_padded,
1121 0,
1122 WM * WN * SIMD_SIZE,
1123 group_size,
1124 bits>;
1125
1126 auto wl = (const device uint8_t*)w;
1127
1128 // Set the block
1129 const int y_row = tid.y * BM;
1130 const int y_col = tid.x * BN;
1131 x += y_row * K;
1132 wl += y_col * bytes_per_pack / pack_factor;
1133 scales += y_col / group_size;
1134 biases += y_col / group_size;
1135 y += y_row * N + y_col;
1136
1137 // Make the x loader and mma operation
1138 const short num_els = min(BM, M - y_row);
1139 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1140 loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
1141 mma_t mma_op(simd_gid, simd_lid);
1142
1143 if (num_els < BM) {
1144 if ((K % BK) != 0) {
1145 const int k_blocks = K / BK;
1146 for (int k = 0; k < k_blocks; k++) {
1147 threadgroup_barrier(mem_flags::mem_threadgroup);
1148 loader_x.load_safe(short2(BK, num_els));
1149 loader_w.load_unsafe();
1150 threadgroup_barrier(mem_flags::mem_threadgroup);
1151 mma_op.mma(Xs, Ws);
1152 loader_x.next();
1153 loader_w.next();
1154 }
1155 const short num_k = K - k_blocks * BK;
1156 threadgroup_barrier(mem_flags::mem_threadgroup);
1157 loader_x.load_safe(short2(num_k, num_els));
1158 loader_w.load_safe(short2(BN, num_k));
1159 threadgroup_barrier(mem_flags::mem_threadgroup);
1160 mma_op.mma(Xs, Ws);
1161 } else {
1162 for (int k = 0; k < K; k += BK) {
1163 threadgroup_barrier(mem_flags::mem_threadgroup);
1164 loader_x.load_safe(short2(BK, num_els));
1165 loader_w.load_unsafe();
1166 threadgroup_barrier(mem_flags::mem_threadgroup);
1167 mma_op.mma(Xs, Ws);
1168 loader_x.next();
1169 loader_w.next();
1170 }
1171 }
1172 } else {
1173 if ((K % BK) != 0) {
1174 const int k_blocks = K / BK;
1175 for (int k = 0; k < k_blocks; k++) {
1176 threadgroup_barrier(mem_flags::mem_threadgroup);
1177 loader_x.load_unsafe();
1178 loader_w.load_unsafe();
1179 threadgroup_barrier(mem_flags::mem_threadgroup);
1180 mma_op.mma(Xs, Ws);
1181 loader_x.next();
1182 loader_w.next();
1183 }
1184 const short num_k = K - k_blocks * BK;
1185 threadgroup_barrier(mem_flags::mem_threadgroup);
1186 loader_x.load_safe(short2(num_k, BM));
1187 loader_w.load_safe(short2(BN, num_k));
1188 threadgroup_barrier(mem_flags::mem_threadgroup);
1189 mma_op.mma(Xs, Ws);
1190 } else {
1191 for (int k = 0; k < K; k += BK) {
1192 threadgroup_barrier(mem_flags::mem_threadgroup);
1193 loader_x.load_unsafe();
1194 loader_w.load_unsafe();
1195 threadgroup_barrier(mem_flags::mem_threadgroup);
1196 mma_op.mma(Xs, Ws);
1197 loader_x.next();
1198 loader_w.next();
1199 }
1200 }
1201 }
1202
1203 // Store results to device memory
1204 threadgroup_barrier(mem_flags::mem_threadgroup);
1205 if (num_els < BM) {
1206 mma_op.store_result_safe(y, N, short2(BN, num_els));
1207 } else {
1208 mma_op.store_result(y, N);
1209 }
1210}
1211
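// Advance x, w, scales, biases and y to the batch element selected by tid.z; the
// second overload gathers the element indices through lhs_indices / rhs_indices.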
1212template <typename T>
1213METAL_FUNC void adjust_matrix_offsets(
1214 const device T*& x,
1215 const device uint32_t*& w,
1216 const device T*& scales,
1217 const device T*& biases,
1218 device T*& y,
1219 int output_stride,
1220 const constant int& x_batch_ndims,
1221 const constant int* x_shape,
1222 const constant int64_t* x_strides,
1223 const constant int& w_batch_ndims,
1224 const constant int* w_shape,
1225 const constant int64_t* w_strides,
1226 const constant int64_t* s_strides,
1227 const constant int64_t* b_strides,
1228 uint3 tid [[threadgroup_position_in_grid]]) {
1229 // Set the input/output matrices
1230 uint32_t x_idx = tid.z;
1231 uint32_t w_idx = tid.z;
1232 if (x_batch_ndims == 1) {
1233 x += x_idx * x_strides[0];
1234 } else {
1235 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1236 }
1237 if (w_batch_ndims == 1) {
1238 w += w_idx * w_strides[0];
1239 scales += w_idx * s_strides[0];
1240 biases += w_idx * b_strides[0];
1241 } else {
1242 ulong3 idx = elem_to_loc_broadcast(
1243 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1244 w += idx.x;
1245 scales += idx.y;
1246 biases += idx.z;
1247 }
1248 y += tid.z * output_stride;
1249}
1250
1251template <typename T>
1252METAL_FUNC void adjust_matrix_offsets(
1253 const device T*& x,
1254 const device uint32_t*& w,
1255 const device T*& scales,
1256 const device T*& biases,
1257 const device uint32_t* lhs_indices,
1258 const device uint32_t* rhs_indices,
1259 device T*& y,
1260 int output_stride,
1261 const constant int& batch_ndims,
1262 const constant int* batch_shape,
1263 const constant int64_t* lhs_strides,
1264 const constant int64_t* rhs_strides,
1265 const constant int& x_batch_ndims,
1266 const constant int* x_shape,
1267 const constant int64_t* x_strides,
1268 const constant int& w_batch_ndims,
1269 const constant int* w_shape,
1270 const constant int64_t* w_strides,
1271 const constant int64_t* s_strides,
1272 const constant int64_t* b_strides,
1273 uint3 tid [[threadgroup_position_in_grid]]) {
1274 // Set the input/output matrices
1275 uint32_t x_idx;
1276 uint32_t w_idx;
1277 if (batch_ndims == 1) {
1278 x_idx = lhs_indices[tid.z * lhs_strides[0]];
1279 w_idx = rhs_indices[tid.z * rhs_strides[0]];
1280 } else {
1281 ulong2 idx = elem_to_loc_broadcast(
1282 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
1283 x_idx = lhs_indices[idx.x];
1284 w_idx = rhs_indices[idx.y];
1285 }
1286 if (x_batch_ndims == 1) {
1287 x += x_idx * x_strides[0];
1288 } else {
1289 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1290 }
1291 if (w_batch_ndims == 1) {
1292 w += w_idx * w_strides[0];
1293 scales += w_idx * s_strides[0];
1294 biases += w_idx * b_strides[0];
1295 } else {
1296 ulong3 idx = elem_to_loc_broadcast(
1297 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1298 w += idx.x;
1299 scales += idx.y;
1300 biases += idx.z;
1301 }
1302 y += tid.z * output_stride;
1303}
1304
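// Kernel entry points: each wrapper adjusts the batch offsets when needed and then
// dispatches to the corresponding *_impl function above.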
1305template <typename T, int group_size, int bits, int D, bool batched>
1306[[kernel]] void qmv_quad(
1307 const device uint32_t* w [[buffer(0)]],
1308 const device T* scales [[buffer(1)]],
1309 const device T* biases [[buffer(2)]],
1310 const device T* x [[buffer(3)]],
1311 device T* y [[buffer(4)]],
1312 const constant int& in_vec_size [[buffer(5)]],
1313 const constant int& out_vec_size [[buffer(6)]],
1314 const constant int& x_batch_ndims [[buffer(7)]],
1315 const constant int* x_shape [[buffer(8)]],
1316 const constant int64_t* x_strides [[buffer(9)]],
1317 const constant int& w_batch_ndims [[buffer(10)]],
1318 const constant int* w_shape [[buffer(11)]],
1319 const constant int64_t* w_strides [[buffer(12)]],
1320 const constant int64_t* s_strides [[buffer(13)]],
1321 const constant int64_t* b_strides [[buffer(14)]],
1322 uint3 tid [[threadgroup_position_in_grid]],
1323 uint quad_gid [[quadgroup_index_in_threadgroup]],
1324 uint quad_lid [[thread_index_in_quadgroup]]) {
1325 if (batched) {
1326 int M = x_shape[x_batch_ndims];
1327 adjust_matrix_offsets<T>(
1328 x,
1329 w,
1330 scales,
1331 biases,
1332 y,
1333 out_vec_size * M,
1334 x_batch_ndims,
1335 x_shape,
1336 x_strides,
1337 w_batch_ndims,
1338 w_shape,
1339 w_strides,
1340 s_strides,
1341 b_strides,
1342 tid);
1343 }
1344 qmv_quad_impl<T, group_size, bits, D>(
1345 w,
1346 scales,
1347 biases,
1348 x,
1349 y,
1350 in_vec_size,
1351 out_vec_size,
1352 tid,
1353 quad_gid,
1354 quad_lid);
1355}
1356
1357template <typename T, int group_size, int bits, bool batched>
1358[[kernel]] void qmv_fast(
1359 const device uint32_t* w [[buffer(0)]],
1360 const device T* scales [[buffer(1)]],
1361 const device T* biases [[buffer(2)]],
1362 const device T* x [[buffer(3)]],
1363 device T* y [[buffer(4)]],
1364 const constant int& in_vec_size [[buffer(5)]],
1365 const constant int& out_vec_size [[buffer(6)]],
1366 const constant int& x_batch_ndims [[buffer(7)]],
1367 const constant int* x_shape [[buffer(8)]],
1368 const constant int64_t* x_strides [[buffer(9)]],
1369 const constant int& w_batch_ndims [[buffer(10)]],
1370 const constant int* w_shape [[buffer(11)]],
1371 const constant int64_t* w_strides [[buffer(12)]],
1372 const constant int64_t* s_strides [[buffer(13)]],
1373 const constant int64_t* b_strides [[buffer(14)]],
1374 uint3 tid [[threadgroup_position_in_grid]],
1375 uint simd_gid [[simdgroup_index_in_threadgroup]],
1376 uint simd_lid [[thread_index_in_simdgroup]]) {
1377 if (batched) {
1378 int M = x_shape[x_batch_ndims];
1379 adjust_matrix_offsets<T>(
1380 x,
1381 w,
1382 scales,
1383 biases,
1384 y,
1385 out_vec_size * M,
1386 x_batch_ndims,
1387 x_shape,
1388 x_strides,
1389 w_batch_ndims,
1390 w_shape,
1391 w_strides,
1392 s_strides,
1393 b_strides,
1394 tid);
1395 }
1396 qmv_fast_impl<T, group_size, bits>(
1397 w,
1398 scales,
1399 biases,
1400 x,
1401 y,
1402 in_vec_size,
1403 out_vec_size,
1404 tid,
1405 simd_gid,
1406 simd_lid);
1407}
1408
1409template <typename T, const int group_size, const int bits, bool batched>
1410[[kernel]] void qmv(
1411 const device uint32_t* w [[buffer(0)]],
1412 const device T* scales [[buffer(1)]],
1413 const device T* biases [[buffer(2)]],
1414 const device T* x [[buffer(3)]],
1415 device T* y [[buffer(4)]],
1416 const constant int& in_vec_size [[buffer(5)]],
1417 const constant int& out_vec_size [[buffer(6)]],
1418 const constant int& x_batch_ndims [[buffer(7)]],
1419 const constant int* x_shape [[buffer(8)]],
1420 const constant int64_t* x_strides [[buffer(9)]],
1421 const constant int& w_batch_ndims [[buffer(10)]],
1422 const constant int* w_shape [[buffer(11)]],
1423 const constant int64_t* w_strides [[buffer(12)]],
1424 const constant int64_t* s_strides [[buffer(13)]],
1425 const constant int64_t* b_strides [[buffer(14)]],
1426 uint3 tid [[threadgroup_position_in_grid]],
1427 uint simd_gid [[simdgroup_index_in_threadgroup]],
1428 uint simd_lid [[thread_index_in_simdgroup]]) {
1429 if (batched) {
1430 int M = x_shape[x_batch_ndims];
1431 adjust_matrix_offsets<T>(
1432 x,
1433 w,
1434 scales,
1435 biases,
1436 y,
1437 out_vec_size * M,
1438 x_batch_ndims,
1439 x_shape,
1440 x_strides,
1441 w_batch_ndims,
1442 w_shape,
1443 w_strides,
1444 s_strides,
1445 b_strides,
1446 tid);
1447 }
1448 qmv_impl<T, group_size, bits>(
1449 w,
1450 scales,
1451 biases,
1452 x,
1453 y,
1454 in_vec_size,
1455 out_vec_size,
1456 tid,
1457 simd_gid,
1458 simd_lid);
1459}
1460
1461template <typename T, const int group_size, const int bits, bool batched>
1462[[kernel]] void qvm(
1463 const device uint32_t* w [[buffer(0)]],
1464 const device T* scales [[buffer(1)]],
1465 const device T* biases [[buffer(2)]],
1466 const device T* x [[buffer(3)]],
1467 device T* y [[buffer(4)]],
1468 const constant int& in_vec_size [[buffer(5)]],
1469 const constant int& out_vec_size [[buffer(6)]],
1470 const constant int& x_batch_ndims [[buffer(7)]],
1471 const constant int* x_shape [[buffer(8)]],
1472 const constant int64_t* x_strides [[buffer(9)]],
1473 const constant int& w_batch_ndims [[buffer(10)]],
1474 const constant int* w_shape [[buffer(11)]],
1475 const constant int64_t* w_strides [[buffer(12)]],
1476 const constant int64_t* s_strides [[buffer(13)]],
1477 const constant int64_t* b_strides [[buffer(14)]],
1478 uint3 tid [[threadgroup_position_in_grid]],
1479 uint simd_gid [[simdgroup_index_in_threadgroup]],
1480 uint simd_lid [[thread_index_in_simdgroup]]) {
1481 if (batched) {
1482 int M = x_shape[x_batch_ndims];
1483 adjust_matrix_offsets<T>(
1484 x,
1485 w,
1486 scales,
1487 biases,
1488 y,
1489 out_vec_size * M,
1490 x_batch_ndims,
1491 x_shape,
1492 x_strides,
1493 w_batch_ndims,
1494 w_shape,
1495 w_strides,
1496 s_strides,
1497 b_strides,
1498 tid);
1499 }
1500 qvm_impl<T, group_size, bits>(
1501 w,
1502 scales,
1503 biases,
1504 x,
1505 y,
1506 in_vec_size,
1507 out_vec_size,
1508 tid,
1509 simd_gid,
1510 simd_lid);
1511}
1512
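// Split-K variant of qvm: tid.z also selects a partition of the reduction dimension,
// and the last partition uses final_block_size when in_vec_size % split_k != 0.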
1513template <typename T, const int group_size, const int bits, int split_k = 32>
1514[[kernel]] void qvm_split_k(
1515 const device uint32_t* w [[buffer(0)]],
1516 const device T* scales [[buffer(1)]],
1517 const device T* biases [[buffer(2)]],
1518 const device T* x [[buffer(3)]],
1519 device T* y [[buffer(4)]],
1520 const constant int& in_vec_size [[buffer(5)]],
1521 const constant int& out_vec_size [[buffer(6)]],
1522 const constant int& x_batch_ndims [[buffer(7)]],
1523 const constant int* x_shape [[buffer(8)]],
1524 const constant int64_t* x_strides [[buffer(9)]],
1525 const constant int& w_batch_ndims [[buffer(10)]],
1526 const constant int* w_shape [[buffer(11)]],
1527 const constant int64_t* w_strides [[buffer(12)]],
1528 const constant int64_t* s_strides [[buffer(13)]],
1529 const constant int64_t* b_strides [[buffer(14)]],
1530 const constant int& final_block_size [[buffer(15)]],
1531 uint3 tid [[threadgroup_position_in_grid]],
1532 uint simd_gid [[simdgroup_index_in_threadgroup]],
1533 uint simd_lid [[thread_index_in_simdgroup]]) {
1534 int M = x_shape[x_batch_ndims];
1535 adjust_matrix_offsets<T>(
1536 x,
1537 w,
1538 scales,
1539 biases,
1540 y,
1541 out_vec_size * M,
1542 x_batch_ndims,
1543 x_shape,
1544 x_strides,
1545 w_batch_ndims,
1546 w_shape,
1547 w_strides,
1548 s_strides,
1549 b_strides,
1550 tid);
1551
1552 // When (in_vec_size % split_k != 0) the final block needs to be smaller
1553 int in_vec_size_adj =
1554 tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
1555
1556 qvm_impl<T, group_size, bits>(
1557 w,
1558 scales,
1559 biases,
1560 x,
1561 y,
1562 in_vec_size_adj,
1563 out_vec_size,
1564 tid,
1565 simd_gid,
1566 simd_lid);
1567}
1568
1569template <
1570 typename T,
1571 const int group_size,
1572 const int bits,
1573 const bool aligned_N,
1574 const bool batched,
1575 const int BM = 32,
1576 const int BK = 32,
1577 const int BN = 32>
1578[[kernel]] void qmm_t(
1579 const device uint32_t* w [[buffer(0)]],
1580 const device T* scales [[buffer(1)]],
1581 const device T* biases [[buffer(2)]],
1582 const device T* x [[buffer(3)]],
1583 device T* y [[buffer(4)]],
1584 const constant int& K [[buffer(5)]],
1585 const constant int& N [[buffer(6)]],
1586 const constant int& M [[buffer(7)]],
1587 const constant int& x_batch_ndims [[buffer(8)]],
1588 const constant int* x_shape [[buffer(9)]],
1589 const constant int64_t* x_strides [[buffer(10)]],
1590 const constant int& w_batch_ndims [[buffer(11)]],
1591 const constant int* w_shape [[buffer(12)]],
1592 const constant int64_t* w_strides [[buffer(13)]],
1593 const constant int64_t* s_strides [[buffer(14)]],
1594 const constant int64_t* b_strides [[buffer(15)]],
1595 uint3 tid [[threadgroup_position_in_grid]],
1596 uint lid [[thread_index_in_threadgroup]],
1597 uint simd_gid [[simdgroup_index_in_threadgroup]],
1598 uint simd_lid [[thread_index_in_simdgroup]]) {
1599 (void)lid;
1600
1601 constexpr int BK_padded = (BK + 16 / sizeof(T));
1602
1603 threadgroup T Xs[BM * BK_padded];
1604 threadgroup T Ws[BN * BK_padded];
1605
1606 if (batched) {
1607 adjust_matrix_offsets<T>(
1608 x,
1609 w,
1610 scales,
1611 biases,
1612 y,
1613 M * N,
1614 x_batch_ndims,
1615 x_shape,
1616 x_strides,
1617 w_batch_ndims,
1618 w_shape,
1619 w_strides,
1620 s_strides,
1621 b_strides,
1622 tid);
1623 }
1624 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1625 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1626}
1627
1628template <
1629 typename T,
1630 const int group_size,
1631 const int bits,
1632 const bool batched,
1633 const int BM = 32,
1634 const int BK = 32,
1635 const int BN = 32>
1636[[kernel]] void qmm_n(
1637 const device uint32_t* w [[buffer(0)]],
1638 const device T* scales [[buffer(1)]],
1639 const device T* biases [[buffer(2)]],
1640 const device T* x [[buffer(3)]],
1641 device T* y [[buffer(4)]],
1642 const constant int& K [[buffer(5)]],
1643 const constant int& N [[buffer(6)]],
1644 const constant int& M [[buffer(7)]],
1645 const constant int& x_batch_ndims [[buffer(8)]],
1646 const constant int* x_shape [[buffer(9)]],
1647 const constant int64_t* x_strides [[buffer(10)]],
1648 const constant int& w_batch_ndims [[buffer(11)]],
1649 const constant int* w_shape [[buffer(12)]],
1650 const constant int64_t* w_strides [[buffer(13)]],
1651 const constant int64_t* s_strides [[buffer(14)]],
1652 const constant int64_t* b_strides [[buffer(15)]],
1653 uint3 tid [[threadgroup_position_in_grid]],
1654 uint lid [[thread_index_in_threadgroup]],
1655 uint simd_gid [[simdgroup_index_in_threadgroup]],
1656 uint simd_lid [[thread_index_in_simdgroup]]) {
1657 (void)lid;
1658
1659 constexpr int BK_padded = (BK + 16 / sizeof(T));
1660 constexpr int BN_padded = (BN + 16 / sizeof(T));
1661
1662 threadgroup T Xs[BM * BK_padded];
1663 threadgroup T Ws[BK * BN_padded];
1664
1665 if (batched) {
1666 adjust_matrix_offsets<T>(
1667 x,
1668 w,
1669 scales,
1670 biases,
1671 y,
1672 M * N,
1673 x_batch_ndims,
1674 x_shape,
1675 x_strides,
1676 w_batch_ndims,
1677 w_shape,
1678 w_strides,
1679 s_strides,
1680 b_strides,
1681 tid);
1682 }
1683
1684 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1685 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1686}
1687
1688template <typename T, int group_size, int bits>
1689[[kernel]] void bs_qmv_fast(
1690 const device uint32_t* w [[buffer(0)]],
1691 const device T* scales [[buffer(1)]],
1692 const device T* biases [[buffer(2)]],
1693 const device T* x [[buffer(3)]],
1694 device T* y [[buffer(4)]],
1695 const constant int& in_vec_size [[buffer(5)]],
1696 const constant int& out_vec_size [[buffer(6)]],
1697 const constant int& x_batch_ndims [[buffer(7)]],
1698 const constant int* x_shape [[buffer(8)]],
1699 const constant int64_t* x_strides [[buffer(9)]],
1700 const constant int& w_batch_ndims [[buffer(10)]],
1701 const constant int* w_shape [[buffer(11)]],
1702 const constant int64_t* w_strides [[buffer(12)]],
1703 const constant int64_t* s_strides [[buffer(13)]],
1704 const constant int64_t* b_strides [[buffer(14)]],
1705 const constant int& batch_ndims [[buffer(15)]],
1706 const constant int* batch_shape [[buffer(16)]],
1707 const device uint32_t* lhs_indices [[buffer(17)]],
1708 const device uint32_t* rhs_indices [[buffer(18)]],
1709 const constant int64_t* lhs_strides [[buffer(19)]],
1710 const constant int64_t* rhs_strides [[buffer(20)]],
1711 uint3 tid [[threadgroup_position_in_grid]],
1712 uint simd_gid [[simdgroup_index_in_threadgroup]],
1713 uint simd_lid [[thread_index_in_simdgroup]]) {
1714 int M = x_shape[x_batch_ndims];
1715 adjust_matrix_offsets<T>(
1716 x,
1717 w,
1718 scales,
1719 biases,
1720 lhs_indices,
1721 rhs_indices,
1722 y,
1723 out_vec_size * M,
1724 batch_ndims,
1725 batch_shape,
1726 lhs_strides,
1727 rhs_strides,
1728 x_batch_ndims,
1729 x_shape,
1730 x_strides,
1731 w_batch_ndims,
1732 w_shape,
1733 w_strides,
1734 s_strides,
1735 b_strides,
1736 tid);
1737 qmv_fast_impl<T, group_size, bits>(
1738 w,
1739 scales,
1740 biases,
1741 x,
1742 y,
1743 in_vec_size,
1744 out_vec_size,
1745 tid,
1746 simd_gid,
1747 simd_lid);
1748}
1749
1750template <typename T, int group_size, int bits>
1751[[kernel]] void bs_qmv(
1752 const device uint32_t* w [[buffer(0)]],
1753 const device T* scales [[buffer(1)]],
1754 const device T* biases [[buffer(2)]],
1755 const device T* x [[buffer(3)]],
1756 device T* y [[buffer(4)]],
1757 const constant int& in_vec_size [[buffer(5)]],
1758 const constant int& out_vec_size [[buffer(6)]],
1759 const constant int& x_batch_ndims [[buffer(7)]],
1760 const constant int* x_shape [[buffer(8)]],
1761 const constant int64_t* x_strides [[buffer(9)]],
1762 const constant int& w_batch_ndims [[buffer(10)]],
1763 const constant int* w_shape [[buffer(11)]],
1764 const constant int64_t* w_strides [[buffer(12)]],
1765 const constant int64_t* s_strides [[buffer(13)]],
1766 const constant int64_t* b_strides [[buffer(14)]],
1767 const constant int& batch_ndims [[buffer(15)]],
1768 const constant int* batch_shape [[buffer(16)]],
1769 const device uint32_t* lhs_indices [[buffer(17)]],
1770 const device uint32_t* rhs_indices [[buffer(18)]],
1771 const constant int64_t* lhs_strides [[buffer(19)]],
1772 const constant int64_t* rhs_strides [[buffer(20)]],
1773 uint3 tid [[threadgroup_position_in_grid]],
1774 uint simd_gid [[simdgroup_index_in_threadgroup]],
1775 uint simd_lid [[thread_index_in_simdgroup]]) {
1776 int M = x_shape[x_batch_ndims];
1777 adjust_matrix_offsets<T>(
1778 x,
1779 w,
1780 scales,
1781 biases,
1782 lhs_indices,
1783 rhs_indices,
1784 y,
1785 out_vec_size * M,
1786 batch_ndims,
1787 batch_shape,
1788 lhs_strides,
1789 rhs_strides,
1790 x_batch_ndims,
1791 x_shape,
1792 x_strides,
1793 w_batch_ndims,
1794 w_shape,
1795 w_strides,
1796 s_strides,
1797 b_strides,
1798 tid);
1799 qmv_impl<T, group_size, bits>(
1800 w,
1801 scales,
1802 biases,
1803 x,
1804 y,
1805 in_vec_size,
1806 out_vec_size,
1807 tid,
1808 simd_gid,
1809 simd_lid);
1810}
1811
1812template <typename T, int group_size, int bits>
1813[[kernel]] void bs_qvm(
1814 const device uint32_t* w [[buffer(0)]],
1815 const device T* scales [[buffer(1)]],
1816 const device T* biases [[buffer(2)]],
1817 const device T* x [[buffer(3)]],
1818 device T* y [[buffer(4)]],
1819 const constant int& in_vec_size [[buffer(5)]],
1820 const constant int& out_vec_size [[buffer(6)]],
1821 const constant int& x_batch_ndims [[buffer(7)]],
1822 const constant int* x_shape [[buffer(8)]],
1823 const constant int64_t* x_strides [[buffer(9)]],
1824 const constant int& w_batch_ndims [[buffer(10)]],
1825 const constant int* w_shape [[buffer(11)]],
1826 const constant int64_t* w_strides [[buffer(12)]],
1827 const constant int64_t* s_strides [[buffer(13)]],
1828 const constant int64_t* b_strides [[buffer(14)]],
1829 const constant int& batch_ndims [[buffer(15)]],
1830 const constant int* batch_shape [[buffer(16)]],
1831 const device uint32_t* lhs_indices [[buffer(17)]],
1832 const device uint32_t* rhs_indices [[buffer(18)]],
1833 const constant int64_t* lhs_strides [[buffer(19)]],
1834 const constant int64_t* rhs_strides [[buffer(20)]],
1835 uint3 tid [[threadgroup_position_in_grid]],
1836 uint simd_gid [[simdgroup_index_in_threadgroup]],
1837 uint simd_lid [[thread_index_in_simdgroup]]) {
1838 int M = x_shape[x_batch_ndims];
1839 adjust_matrix_offsets<T>(
1840 x,
1841 w,
1842 scales,
1843 biases,
1844 lhs_indices,
1845 rhs_indices,
1846 y,
1847 out_vec_size * M,
1848 batch_ndims,
1849 batch_shape,
1850 lhs_strides,
1851 rhs_strides,
1852 x_batch_ndims,
1853 x_shape,
1854 x_strides,
1855 w_batch_ndims,
1856 w_shape,
1857 w_strides,
1858 s_strides,
1859 b_strides,
1860 tid);
1861 qvm_impl<T, group_size, bits>(
1862 w,
1863 scales,
1864 biases,
1865 x,
1866 y,
1867 in_vec_size,
1868 out_vec_size,
1869 tid,
1870 simd_gid,
1871 simd_lid);
1872}
1873
1874template <
1875 typename T,
1876 const int group_size,
1877 const int bits,
1878 const bool aligned_N,
1879 const int BM = 32,
1880 const int BK = 32,
1881 const int BN = 32>
1882[[kernel]] void bs_qmm_t(
1883 const device uint32_t* w [[buffer(0)]],
1884 const device T* scales [[buffer(1)]],
1885 const device T* biases [[buffer(2)]],
1886 const device T* x [[buffer(3)]],
1887 device T* y [[buffer(4)]],
1888 const constant int& K [[buffer(5)]],
1889 const constant int& N [[buffer(6)]],
1890 const constant int& M [[buffer(7)]],
1891 const constant int& x_batch_ndims [[buffer(8)]],
1892 const constant int* x_shape [[buffer(9)]],
1893 const constant int64_t* x_strides [[buffer(10)]],
1894 const constant int& w_batch_ndims [[buffer(11)]],
1895 const constant int* w_shape [[buffer(12)]],
1896 const constant int64_t* w_strides [[buffer(13)]],
1897 const constant int64_t* s_strides [[buffer(14)]],
1898 const constant int64_t* b_strides [[buffer(15)]],
1899 const constant int& batch_ndims [[buffer(16)]],
1900 const constant int* batch_shape [[buffer(17)]],
1901 const device uint32_t* lhs_indices [[buffer(18)]],
1902 const device uint32_t* rhs_indices [[buffer(19)]],
1903 const constant int64_t* lhs_strides [[buffer(20)]],
1904 const constant int64_t* rhs_strides [[buffer(21)]],
1905 uint3 tid [[threadgroup_position_in_grid]],
1906 uint lid [[thread_index_in_threadgroup]],
1907 uint simd_gid [[simdgroup_index_in_threadgroup]],
1908 uint simd_lid [[thread_index_in_simdgroup]]) {
1909 (void)lid;
1910
1911 constexpr int BK_padded = (BK + 16 / sizeof(T));
1912
1913 threadgroup T Xs[BM * BK_padded];
1914 threadgroup T Ws[BN * BK_padded];
1915
1916 adjust_matrix_offsets<T>(
1917 x,
1918 w,
1919 scales,
1920 biases,
1921 lhs_indices,
1922 rhs_indices,
1923 y,
1924 M * N,
1925 batch_ndims,
1926 batch_shape,
1927 lhs_strides,
1928 rhs_strides,
1929 x_batch_ndims,
1930 x_shape,
1931 x_strides,
1932 w_batch_ndims,
1933 w_shape,
1934 w_strides,
1935 s_strides,
1936 b_strides,
1937 tid);
1938 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1939 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1940}
1941
1942template <
1943 typename T,
1944 const int group_size,
1945 const int bits,
1946 const int BM = 32,
1947 const int BK = 32,
1948 const int BN = 32>
1949[[kernel]] void bs_qmm_n(
1950 const device uint32_t* w [[buffer(0)]],
1951 const device T* scales [[buffer(1)]],
1952 const device T* biases [[buffer(2)]],
1953 const device T* x [[buffer(3)]],
1954 device T* y [[buffer(4)]],
1955 const constant int& K [[buffer(5)]],
1956 const constant int& N [[buffer(6)]],
1957 const constant int& M [[buffer(7)]],
1958 const constant int& x_batch_ndims [[buffer(8)]],
1959 const constant int* x_shape [[buffer(9)]],
1960 const constant int64_t* x_strides [[buffer(10)]],
1961 const constant int& w_batch_ndims [[buffer(11)]],
1962 const constant int* w_shape [[buffer(12)]],
1963 const constant int64_t* w_strides [[buffer(13)]],
1964 const constant int64_t* s_strides [[buffer(14)]],
1965 const constant int64_t* b_strides [[buffer(15)]],
1966 const constant int& batch_ndims [[buffer(16)]],
1967 const constant int* batch_shape [[buffer(17)]],
1968 const device uint32_t* lhs_indices [[buffer(18)]],
1969 const device uint32_t* rhs_indices [[buffer(19)]],
1970 const constant int64_t* lhs_strides [[buffer(20)]],
1971 const constant int64_t* rhs_strides [[buffer(21)]],
1972 uint3 tid [[threadgroup_position_in_grid]],
1973 uint lid [[thread_index_in_threadgroup]],
1974 uint simd_gid [[simdgroup_index_in_threadgroup]],
1975 uint simd_lid [[thread_index_in_simdgroup]]) {
1976 (void)lid;
1977
1978 constexpr int BK_padded = (BK + 16 / sizeof(T));
1979 constexpr int BN_padded = (BN + 16 / sizeof(T));
1980
1981 threadgroup T Xs[BM * BK_padded];
1982 threadgroup T Ws[BK * BN_padded];
1983
1984 adjust_matrix_offsets<T>(
1985 x,
1986 w,
1987 scales,
1988 biases,
1989 lhs_indices,
1990 rhs_indices,
1991 y,
1992 M * N,
1993 batch_ndims,
1994 batch_shape,
1995 lhs_strides,
1996 rhs_strides,
1997 x_batch_ndims,
1998 x_shape,
1999 x_strides,
2000 w_batch_ndims,
2001 w_shape,
2002 w_strides,
2003 s_strides,
2004 b_strides,
2005 tid);
2006 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
2007 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
2008}
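//
// Illustrative sketch (not part of quantized.h): the bs_* kernels above all
// follow the same pattern, namely use lhs_indices / rhs_indices to pick which
// activation and which quantized weight matrix a given batch element reads
// (adjust_matrix_offsets moves the x / w / scales / biases / y pointers), then
// hand off to the corresponding dense *_impl routine. Below is a scalar C++
// outline of that bookkeeping; gathered_qmv_reference and dequant_matvec are
// hypothetical names used only for this example.

#include <cstdint>
#include <vector>

void gathered_qmv_reference(
    const std::vector<const float*>& x_batches,    // candidate activation vectors
    const std::vector<const uint32_t*>& w_batches, // candidate packed weight matrices
    const std::vector<uint32_t>& lhs_indices,      // activation index per output batch
    const std::vector<uint32_t>& rhs_indices,      // weight index per output batch
    float* y,
    int out_vec_size,
    // Stand-in for the dense quantized routine (the role of qvm_impl / qmm_*_impl).
    void (*dequant_matvec)(const uint32_t* w, const float* x, float* y)) {
  for (size_t b = 0; b < lhs_indices.size(); ++b) {
    // Gather the operands for this batch element, then run the dense kernel.
    const float* x = x_batches[lhs_indices[b]];
    const uint32_t* w = w_batches[rhs_indices[b]];
    dequant_matvec(w, x, y + b * size_t(out_vec_size));
  }
}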
2009
2010template <typename T, const int group_size, const int bits>
2011[[kernel]] void affine_quantize(
2012 const device T* w [[buffer(0)]],
2013 device uint8_t* out [[buffer(1)]],
2014 device T* scales [[buffer(2)]],
2015 device T* biases [[buffer(3)]],
2016 uint2 index [[thread_position_in_grid]],
2017 uint2 grid_dim [[threads_per_grid]]) {
2018 constexpr T eps = T(1e-7);
2019 constexpr int simd_size = 32;
2020 constexpr T n_bins = (1 << bits) - 1;
2021 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2022 constexpr int values_per_reduce = group_size / simd_size;
2023 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
2024 constexpr int writes_per_pack =
2025 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
2026 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2027 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2028
2029 static_assert(
2030 group_size % simd_size == 0,
2031 "Group size must be divisible by simd size.");
2032
2033 size_t offset = index.x + grid_dim.x * size_t(index.y);
2034 size_t in_index = offset * values_per_reduce;
2035 size_t out_index = power_of_2_bits
2036 ? offset * writes_per_pack
2037 : offset * bytes_per_pack / writes_per_reduce;
2038
2039 T w_thread[values_per_reduce];
2040 T w_min = Limits<T>::max;
2041 T w_max = 0;
2042
2043#pragma clang loop unroll(full)
2044 for (int i = 0; i < values_per_reduce; i++) {
2045 T val = w[in_index + i];
2046 w_thread[i] = val;
2047 w_min = min(w_min, val);
2048 w_max = max(w_max, val);
2049 }
2050
2051 w_min = simd_min(w_min);
2052 w_max = simd_max(w_max);
2053
2054 T scale = max((w_max - w_min) / n_bins, eps);
2055 bool side = abs(w_min) > abs(w_max);
2056 scale = side ? scale : -scale;
2057 T edge = side ? w_min : w_max;
2058 T q0 = round(edge / scale);
2059 bool at_zero = q0 == 0.0f;
2060 scale = at_zero ? scale : edge / q0;
2061 T bias = at_zero ? T(0) : edge;
2062
2063 // Write out the scales and biases
2064 size_t gindex = in_index / group_size;
2065 if (in_index % group_size == 0) {
2066 scales[gindex] = scale;
2067 biases[gindex] = bias;
2068 }
2069
2070 // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
2071 uint32_t output = 0;
2072
2073#pragma clang loop unroll(full)
2074 for (int i = 0; i < values_per_reduce; i++) {
2075 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
2076 if (bits == 8) {
2077 output = val;
2078 } else {
2079 output += val << (bits * (i % packs_per_int));
2080 }
2081
2082 if (packs_per_int < values_per_reduce &&
2083 i % packs_per_int == packs_per_int - 1) {
2084 out[out_index + i / packs_per_int] = output;
2085 output = 0;
2086 } else {
2087#pragma clang loop unroll(full)
2088 for (int j = 1; j < writes_per_reduce; j++) {
2089 uint8_t sval = simd_shuffle_down(val, j);
2090 output += sval << (bits * (j * values_per_reduce + i));
2091 }
2092 }
2093 }
2094 if (bits == 3 || bits == 6) {
2095 if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
2096 out[out_index] = output & 0xff;
2097 out[out_index + 1] = (output & 0xff00) >> 8;
2098 out[out_index + 2] = (output & 0xff0000) >> 16;
2099 }
2100 } else {
2101 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
2102 out[out_index / writes_per_reduce] = output;
2103 }
2104 }
2105}
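//
// Illustrative sketch (not part of quantized.h): the same per-group affine
// quantization rule as affine_quantize above, written as scalar C++ for one
// group of weights. quantize_group is a hypothetical helper name; the scale is
// taken from the group's min/max range, the larger-magnitude edge becomes the
// bias (so it round-trips exactly to code 0), and q * scale + bias recovers
// the original value approximately.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

template <int bits>
void quantize_group(const std::vector<float>& w, float& scale, float& bias,
                    std::vector<uint8_t>& q) {
  constexpr float n_bins = (1 << bits) - 1; // e.g. 15 codes above 0 for 4-bit
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  scale = std::max((w_max - w_min) / n_bins, 1e-7f); // eps as in the kernel
  // Anchor the grid on whichever edge has the larger magnitude; a negative
  // scale is used when anchoring on the maximum, mirroring the kernel.
  bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = std::round(edge / scale);
  if (q0 != 0.0f) {
    scale = edge / q0;
    bias = edge;
  } else {
    bias = 0.0f;
  }
  q.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<uint8_t>(
        std::min(std::round((w[i] - bias) / scale), n_bins));
  }
}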
2106
2107template <typename T, const int group_size, const int bits>
2108[[kernel]] void affine_dequantize(
2109 const device uint8_t* w [[buffer(0)]],
2110 const device T* scales [[buffer(1)]],
2111 const device T* biases [[buffer(2)]],
2112 device T* out [[buffer(3)]],
2113 uint2 index [[thread_position_in_grid]],
2114 uint2 grid_dim [[threads_per_grid]]) {
2115 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2116 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2117 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2118
2119 size_t offset = index.x + grid_dim.x * size_t(index.y);
2120 size_t oindex = offset * packs_per_int;
2121 size_t gindex = oindex / group_size;
2122 T scale = scales[gindex];
2123 T bias = biases[gindex];
2124
2125 out += oindex;
2126
2127 if (bits == 3) {
2128 w += offset * bytes_per_pack;
2129 out[0] = (w[0] & 0x7) * scale + bias;
2130 out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
2131 out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
2132 out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
2133 out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
2134 out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
2135 out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
2136 out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
2137
2138 } else if (bits == 6) {
2139 w += offset * bytes_per_pack;
2140 out[0] = (w[0] & 0x3f) * scale + bias;
2141 out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
2142 out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
2143 out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
2144 } else {
2145 uint val = w[offset];
2146#pragma clang loop unroll(full)
2147 for (int i = 0; i < packs_per_int; i++) {
2148 uint8_t d;
2149 if (bits == 2) {
2150 d = (val >> (bits * i)) & 0x03;
2151 } else if (bits == 4) {
2152 d = (val >> (bits * i)) & 0x0f;
2153 } else if (bits == 8) {
2154 d = val;
2155 }
2156 out[i] = scale * d + bias;
2157 }
2158 }
2159}
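//
// Illustrative sketch (not part of quantized.h): the 3-bit packing layout that
// affine_quantize and affine_dequantize above agree on. Eight 3-bit codes form
// one little-endian 24-bit word stored as 3 bytes, which is why the dequantize
// path masks with 0x7, 0x38, 0xc0, and so on. pack3 / unpack3 are hypothetical
// helper names used only for this example.

#include <array>
#include <cstdint>

std::array<uint8_t, 3> pack3(const std::array<uint8_t, 8>& q) {
  uint32_t word = 0;
  for (int i = 0; i < 8; ++i) {
    word |= uint32_t(q[i] & 0x7) << (3 * i); // code i occupies bits [3i, 3i + 3)
  }
  // Same store order as the quantize kernel's 3/6-bit path: & 0xff, >> 8, >> 16.
  return {uint8_t(word & 0xff), uint8_t((word >> 8) & 0xff),
          uint8_t((word >> 16) & 0xff)};
}

std::array<uint8_t, 8> unpack3(const std::array<uint8_t, 3>& b) {
  uint32_t word =
      uint32_t(b[0]) | (uint32_t(b[1]) << 8) | (uint32_t(b[2]) << 16);
  std::array<uint8_t, 8> q{};
  for (int i = 0; i < 8; ++i) {
    q[i] = (word >> (3 * i)) & 0x7; // same bit positions as out[0..7] above
  }
  return q;
}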