MLX
quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10MLX_MTL_CONST int SIMD_SIZE = 32;
11MLX_MTL_CONST int QUAD_SIZE = 4;
12
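// load_vector reads values_per_thread consecutive inputs into registers and
// returns their plain sum. Each element is pre-divided by the power of two at
// which its packed weight sits within the byte (e.g. 1, 4, 16, 64 for 2-bit
// packing), so qdot can multiply by masked-but-unshifted weights and still get
// correctly scaled products.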
13template <typename T, typename U, int values_per_thread, int bits>
14inline U load_vector(const device T* x, thread U* x_thread) {
15 static_assert(
16 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
17 "Template undefined for bits not in {2, 3, 4, 6, 8}");
18
19 U sum = 0;
20
21 if (bits == 2) {
22 for (int i = 0; i < values_per_thread; i += 4) {
23 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
24 x_thread[i] = x[i];
25 x_thread[i + 1] = x[i + 1] / 4.0f;
26 x_thread[i + 2] = x[i + 2] / 16.0f;
27 x_thread[i + 3] = x[i + 3] / 64.0f;
28 }
29 }
30
31 else if (bits == 3) {
32 for (int i = 0; i < values_per_thread; i += 8) {
33 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
34 x[i + 6] + x[i + 7];
35 x_thread[i] = x[i];
36 x_thread[i + 1] = x[i + 1] / 8.0f;
37 x_thread[i + 2] = x[i + 2] / 64.0f;
38 x_thread[i + 3] = x[i + 3] / 2.0f;
39 x_thread[i + 4] = x[i + 4] / 16.0f;
40 x_thread[i + 5] = x[i + 5] / 128.0f;
41 x_thread[i + 6] = x[i + 6] / 4.0f;
42 x_thread[i + 7] = x[i + 7] / 32.0f;
43 }
44 }
45
46 else if (bits == 4) {
47 for (int i = 0; i < values_per_thread; i += 4) {
48 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
49 x_thread[i] = x[i];
50 x_thread[i + 1] = x[i + 1] / 16.0f;
51 x_thread[i + 2] = x[i + 2] / 256.0f;
52 x_thread[i + 3] = x[i + 3] / 4096.0f;
53 }
54 }
55
56 else if (bits == 6) {
57 for (int i = 0; i < values_per_thread; i += 4) {
58 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
59 x_thread[i] = x[i];
60 x_thread[i + 1] = x[i + 1] / 64.0f;
61 x_thread[i + 2] = x[i + 2] / 16.0f;
62 x_thread[i + 3] = x[i + 3] / 4.0f;
63 }
64 }
65
66 else if (bits == 8) {
67 for (int i = 0; i < values_per_thread; i++) {
68 sum += x[i];
69 x_thread[i] = x[i];
70 }
71 }
72
73 return sum;
74}
75
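// load_vector_safe: bounds-checked variant of load_vector that loads and sums
// only the first N inputs and zero-fills the rest of x_thread.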
76template <typename T, typename U, int values_per_thread, int bits>
77inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
78 static_assert(
79 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
80 "Template undefined for bits not in {2, 3, 4, 6, 8}");
81
82 U sum = 0;
83
84 if (bits == 2) {
85 for (int i = 0; i < N; i += 4) {
86 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
87 x_thread[i] = x[i];
88 x_thread[i + 1] = x[i + 1] / 4.0f;
89 x_thread[i + 2] = x[i + 2] / 16.0f;
90 x_thread[i + 3] = x[i + 3] / 64.0f;
91 }
92 }
93
94 else if (bits == 3) {
95 for (int i = 0; i < N; i += 8) {
96 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
97 x[i + 6] + x[i + 7];
98
99 x_thread[i] = x[i];
100 x_thread[i + 1] = x[i + 1] / 8.0f;
101 x_thread[i + 2] = x[i + 2] / 64.0f;
102 x_thread[i + 3] = x[i + 3] / 2.0f;
103 x_thread[i + 4] = x[i + 4] / 16.0f;
104 x_thread[i + 5] = x[i + 5] / 128.0f;
105 x_thread[i + 6] = x[i + 6] / 4.0f;
106 x_thread[i + 7] = x[i + 7] / 32.0f;
107 }
108 }
109
110 else if (bits == 4) {
111 for (int i = 0; i < N; i += 4) {
112 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
113 x_thread[i] = x[i];
114 x_thread[i + 1] = x[i + 1] / 16.0f;
115 x_thread[i + 2] = x[i + 2] / 256.0f;
116 x_thread[i + 3] = x[i + 3] / 4096.0f;
117 }
118 }
119
120 else if (bits == 6) {
121 for (int i = 0; i < N; i += 4) {
122 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
123 x_thread[i] = x[i];
124 x_thread[i + 1] = x[i + 1] / 64.0f;
125 x_thread[i + 2] = x[i + 2] / 16.0f;
126 x_thread[i + 3] = x[i + 3] / 4.0f;
127 }
128 }
129
130 else if (bits == 8) {
131 for (int i = 0; i < N; i++) {
132 sum += x[i];
133 x_thread[i] = x[i];
134 }
135 }
136
137 for (int i = N; i < values_per_thread; i++) {
138 x_thread[i] = 0;
139 }
140
141 return sum;
142}
143
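// qdot accumulates the dot product of values_per_thread pre-scaled inputs with
// one packed weight segment. Weights are masked but never shifted: for 4-bit
// packing, (ws[i] & 0x00f0) is 16x the stored value while x_thread[4*i + 1]
// was divided by 16 in load_vector, so the product is exact. Since each weight
// is scale * q + bias, the affine part folds in as scale * accum + sum * bias.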
144template <typename U, int values_per_thread, int bits>
145inline U qdot(
146 const device uint8_t* w,
147 const thread U* x_thread,
148 U scale,
149 U bias,
150 U sum) {
151 static_assert(
152 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
153 "Template undefined for bits not in {2, 3, 4, 6, 8}");
154
155 U accum = 0;
156
157 if (bits == 2) {
158 for (int i = 0; i < (values_per_thread / 4); i++) {
159 accum +=
160 (x_thread[4 * i] * (w[i] & 0x03) +
161 x_thread[4 * i + 1] * (w[i] & 0x0c) +
162 x_thread[4 * i + 2] * (w[i] & 0x30) +
163 x_thread[4 * i + 3] * (w[i] & 0xc0));
164 }
165 }
166
167 else if (bits == 3) {
168 for (int i = 0; i < (values_per_thread / 8); i++) {
169 x_thread += 8 * i;
170 w += 3 * i;
171
172 accum += (w[0] & 0x07) * x_thread[0];
173 accum += (w[0] & 0x38) * x_thread[1];
174 accum += (w[0] & 0xc0) * x_thread[2];
175 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
176
177 accum += (w[1] & 0x0e) * x_thread[3];
178 accum += (w[1] & 0x70) * x_thread[4];
179 accum += (w[1] & 0x80) * x_thread[5];
180 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
181
182 accum += (w[2] & 0x1c) * x_thread[6];
183 accum += (w[2] & 0xe0) * x_thread[7];
184 }
185 }
186
187 else if (bits == 4) {
188 const device uint16_t* ws = (const device uint16_t*)w;
189 for (int i = 0; i < (values_per_thread / 4); i++) {
190 accum +=
191 (x_thread[4 * i] * (ws[i] & 0x000f) +
192 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
193 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
194 x_thread[4 * i + 3] * (ws[i] & 0xf000));
195 }
196 }
197
198 else if (bits == 6) {
199 for (int i = 0; i < (values_per_thread / 4); i++) {
200 x_thread += 4 * i;
201 w += 3 * i;
202
203 accum += (w[0] & 0x3f) * x_thread[0];
204
205 accum += (w[0] & 0xc0) * x_thread[1];
206 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
207
208 accum += (w[1] & 0xf0) * x_thread[2];
209 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
210
211 accum += (w[2] & 0xfc) * x_thread[3];
212 }
213 }
214
215 else if (bits == 8) {
216 for (int i = 0; i < values_per_thread; i++) {
217 accum += x_thread[i] * w[i];
218 }
219 }
220
221 return scale * accum + sum * bias;
222}
223
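// qdot_safe: tail variant of qdot in which only the first N inputs contribute.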
224template <typename U, int values_per_thread, int bits>
225inline U qdot_safe(
226 const device uint8_t* w,
227 const thread U* x_thread,
228 U scale,
229 U bias,
230 U sum,
231 int N) {
232 static_assert(
233 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
234 "Template undefined for bits not in {2, 3, 4, 6, 8}");
235
236 U accum = 0;
237
238 if (bits == 2) {
239 for (int i = 0; i < (N / 4); i++) {
240 accum +=
241 (x_thread[4 * i] * (w[i] & 0x03) +
242 x_thread[4 * i + 1] * (w[i] & 0x0c) +
243 x_thread[4 * i + 2] * (w[i] & 0x30) +
244 x_thread[4 * i + 3] * (w[i] & 0xc0));
245 }
246 }
247
248 else if (bits == 3) {
249 for (int i = 0; i < (N / 8); i++) {
250 x_thread += 8 * i;
251 w += 3 * i;
252
253 accum += (w[0] & 0x07) * x_thread[0];
254 accum += (w[0] & 0x38) * x_thread[1];
255 accum += (w[0] & 0xc0) * x_thread[2];
256 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
257
258 accum += (w[1] & 0x0e) * x_thread[3];
259 accum += (w[1] & 0x70) * x_thread[4];
260 accum += (w[1] & 0x80) * x_thread[5];
261 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
262
263 accum += (w[2] & 0x1c) * x_thread[6];
264 accum += (w[2] & 0xe0) * x_thread[7];
265 }
266 }
267
268 else if (bits == 4) {
269 const device uint16_t* ws = (const device uint16_t*)w;
270 for (int i = 0; i < (N / 4); i++) {
271 accum +=
272 (x_thread[4 * i] * (ws[i] & 0x000f) +
273 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
274 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
275 x_thread[4 * i + 3] * (ws[i] & 0xf000));
276 }
277 }
278
279 else if (bits == 6) {
280 for (int i = 0; i < (N / 4); i++) {
281 x_thread += 4 * i;
282 w += 3 * i;
283
284 accum += (w[0] & 0x3f) * x_thread[0];
285
286 accum += (w[0] & 0xc0) * x_thread[1];
287 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
288
289 accum += (w[1] & 0xf0) * x_thread[2];
290 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
291
292 accum += (w[2] & 0xfc) * x_thread[3];
293 }
294 }
295
296 else if (bits == 8) {
297 for (int i = 0; i < N; i++) {
298 accum += x_thread[i] * w[i];
299 }
300 }
301
302 return scale * accum + sum * bias;
303}
304
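// qouter accumulates x * dequantized(w) into result, i.e. one scalar input
// times a packed weight row segment; used by the vector-matrix (qvm) path.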
305template <typename U, int values_per_thread, int bits>
306inline void
307qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
308 static_assert(
309 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
310 "Template undefined for bits not in {2, 3, 4, 6, 8}");
311
312 if (bits == 2) {
313 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
314 for (int i = 0; i < (values_per_thread / 4); i++) {
315 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
316 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
317 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
318 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
319 }
320 }
321
322 else if (bits == 3) {
323 for (int i = 0; i < (values_per_thread / 8); i++) {
324 uint8_t w0 = w[3 * i];
325 uint8_t w1 = w[3 * i + 1];
326 uint8_t w2 = w[3 * i + 2];
327
328 result[8 * i] += x * ((w0 & 0x7) * scale + bias);
329 result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
330 result[8 * i + 2] +=
331 x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
332 result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
333 result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
334 result[8 * i + 5] +=
335 x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
336 result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
337 result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
338 }
339 }
340
341 else if (bits == 4) {
342 U s[2] = {scale, scale / 16.0f};
343 for (int i = 0; i < (values_per_thread / 2); i++) {
344 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
345 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
346 }
347
348 } else if (bits == 6) {
349 for (int i = 0; i < (values_per_thread / 4); i++) {
350 uint8_t w0 = w[3 * i];
351 uint8_t w1 = w[3 * i + 1];
352 uint8_t w2 = w[3 * i + 2];
353
354 result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
355 result[4 * i + 1] +=
356 x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
357 result[4 * i + 2] +=
358 x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
359 result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
360 }
361 }
362
363 else if (bits == 8) {
364 for (int i = 0; i < values_per_thread; i++) {
365 result[i] += x * (scale * w[i] + bias);
366 }
367 }
368}
369
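// dequantize unpacks N quantized values from w into threadgroup memory as
// w_local[j] = scale * q_j + bias; used by QuantizedBlockLoader below.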
370template <typename U, int N, int bits>
371inline void
372dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
373 static_assert(
374 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
375 "Template undefined for bits not in {2, 3, 4, 6, 8}");
376
377 if (bits == 2) {
378 U s[4] = {
379 scale,
380 scale / static_cast<U>(4.0f),
381 scale / static_cast<U>(16.0f),
382 scale / static_cast<U>(64.0f)};
383 for (int i = 0; i < (N / 4); i++) {
384 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
385 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
386 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
387 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
388 }
389 }
390
391 else if (bits == 3) {
392 for (int i = 0; i < (N / 8); i++) {
393 w_local += 8 * i;
394 w += 3 * i;
395
396 w_local[0] = (w[0] & 0x7) * scale + bias;
397 w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
398 w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
399 w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
400 w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
401 w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
402 w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
403 w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
404 }
405 }
406
407 else if (bits == 4) {
408 U s[2] = {scale, scale / static_cast<U>(16.0f)};
409 for (int i = 0; i < (N / 2); i++) {
410 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
411 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
412 }
413 }
414
415 else if (bits == 6) {
416 for (int i = 0; i < (N / 4); i++) {
417 w_local += 4 * i;
418 w += 3 * i;
419
420 w_local[0] = (w[0] & 0x3f) * scale + bias;
421 w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
422 w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
423 w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
424 }
425 }
426
427 else if (bits == 8) {
428 for (int i = 0; i < N; i++) {
429 w_local[i] = scale * w[i] + bias;
430 }
431 }
432}
433
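// QuantizedBlockLoader cooperatively loads and dequantizes a BROWS x BCOLS
// weight tile into threadgroup memory for the qmm kernels. Each thread handles
// n_reads packs of pack_factor values; next() advances along the reduction
// dimension and bumps the scales/biases pointers whenever a quantization group
// boundary is crossed.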
434template <
435 typename T,
436 short BROWS,
437 short BCOLS,
438 short dst_ld,
439 short reduction_dim,
440 short tgp_size,
441 short group_size,
442 short bits>
443struct QuantizedBlockLoader {
444 static_assert(
445 BCOLS <= group_size,
446 "The group size should be larger than the columns");
447 static_assert(
448 group_size % BCOLS == 0,
449 "The group size should be divisible by the columns");
450 static_assert(
451 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
452 "Template undefined for bits not in {2, 3, 4, 6, 8}");
453
454 MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
455 MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
456 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
457 MLX_MTL_CONST short n_reads =
458 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
459 MLX_MTL_CONST short group_steps = group_size / BCOLS;
460
461 const int src_ld;
462 const int tile_stride;
463 short group_step_cnt;
464 const int group_stride;
465
466 const short thread_idx;
467 const short bi;
468 const short bj;
469
470 threadgroup T* dst;
471 const device uint8_t* src;
472 const device T* scales;
473 const device T* biases;
474
475 QuantizedBlockLoader(
476 const device uint8_t* src_,
477 const device T* scales_,
478 const device T* biases_,
479 const int src_ld_,
480 threadgroup T* dst_,
481 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
482 ushort simd_lane_id [[thread_index_in_simdgroup]])
483 : src_ld(src_ld_),
484 tile_stride(
485 reduction_dim ? BCOLS_PACKED * bytes_per_pack
486 : BROWS * src_ld * bytes_per_pack / pack_factor),
487 group_step_cnt(0),
488 group_stride(BROWS * src_ld / group_size),
489 thread_idx(simd_group_id * 32 + simd_lane_id),
490 bi(n_reads * thread_idx / BCOLS_PACKED),
491 bj((n_reads * thread_idx) % BCOLS_PACKED),
492 dst(dst_ + bi * dst_ld + bj * pack_factor),
493 src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
494 bj * bytes_per_pack),
495 scales(scales_ + bi * src_ld / group_size),
496 biases(biases_ + bi * src_ld / group_size) {}
497
498 void load_unsafe() const {
499 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
500 return;
501 }
502
503 T scale = *scales;
504 T bias = *biases;
505 for (int i = 0; i < n_reads; i++) {
506 dequantize<T, pack_factor, bits>(
507 src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
508 }
509 }
510
511 void load_safe(short2 src_tile_dim) const {
512 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
513 return;
514 }
515
516 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
517 for (int i = 0; i < n_reads * pack_factor; i++) {
518 dst[i] = T(0);
519 }
520 return;
521 }
522
523 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
524 for (int i = 0; i < n_reads * pack_factor; i++) {
525 dst[i] = T(0);
526 }
527 return;
528 }
529
530 T scale = *scales;
531 T bias = *biases;
532 for (int i = 0; i < n_reads; i++) {
533 dequantize<T, pack_factor, bits>(
534 (device uint8_t*)(src + i * bytes_per_pack),
535 scale,
536 bias,
537 dst + i * pack_factor);
538 }
539 }
540
541 void next() {
542 src += tile_stride;
543 if (reduction_dim == 1) {
544 if (group_steps > 1) {
545 group_step_cnt++;
546 if (group_step_cnt == group_steps) {
547 group_step_cnt = 0;
548 scales++;
549 biases++;
550 }
551 } else {
552 scales++;
553 biases++;
554 }
555 } else {
556 scales += group_stride;
557 biases += group_stride;
558 }
559 }
560};
561
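// qmv_quad_impl: quantized matrix-vector product for small input sizes. The D
// inputs are split across a quadgroup (QUAD_SIZE threads); each quadgroup
// accumulates results_per_quadgroup output rows and reduces them via quad_sum.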
562template <typename T, int group_size, int bits, int D>
563METAL_FUNC void qmv_quad_impl(
564 const device uint32_t* w,
565 const device T* scales,
566 const device T* biases,
567 const device T* x,
568 device T* y,
569 constant int& in_vec_size,
570 const constant int& out_vec_size,
571 uint3 tid [[threadgroup_position_in_grid]],
572 uint quad_gid [[quadgroup_index_in_threadgroup]],
573 uint quad_lid [[thread_index_in_quadgroup]]) {
574 constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
575 constexpr int pack_factor = 32 / bits;
576 constexpr int values_per_thread = D / QUAD_SIZE;
577 constexpr int packs_per_thread = values_per_thread / pack_factor;
578 constexpr int scale_step_per_thread = group_size / values_per_thread;
579 constexpr int results_per_quadgroup = 8;
580
581 typedef float U;
582
583 thread U x_thread[values_per_thread];
584 thread U result[results_per_quadgroup] = {0};
585
586 // Adjust positions
587 const int in_vec_size_w = in_vec_size / pack_factor;
588 const int in_vec_size_g = in_vec_size / group_size;
589 const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
590
591 w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
592 scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
593 biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
594 x += tid.y * in_vec_size + quad_lid * values_per_thread;
595 y += tid.y * out_vec_size + out_row;
596
597 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
598
599 for (int row = 0; row < results_per_quadgroup; row++) {
600 auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
601 const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
602 const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
603
604 U s = sl[0];
605 U b = bl[0];
606 if (row * quads_per_simd + out_row < out_vec_size) {
607 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
608 }
609 }
610
611 for (int row = 0; row < results_per_quadgroup; row++) {
612 result[row] = quad_sum(result[row]);
613 if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
614 y[row * quads_per_simd] = static_cast<T>(result[row]);
615 }
616 }
617}
618
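// qmv_fast_impl: quantized matrix-vector product for the aligned case where no
// bounds checks are needed; each of the two simdgroups in a threadgroup
// produces results_per_simdgroup output rows.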
619template <typename T, int group_size, int bits>
620METAL_FUNC void qmv_fast_impl(
621 const device uint32_t* w,
622 const device T* scales,
623 const device T* biases,
624 const device T* x,
625 device T* y,
626 const constant int& in_vec_size,
627 const constant int& out_vec_size,
628 uint3 tid [[threadgroup_position_in_grid]],
629 uint simd_gid [[simdgroup_index_in_threadgroup]],
630 uint simd_lid [[thread_index_in_simdgroup]]) {
631 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
632 constexpr int packs_per_thread = bits == 2 ? 1 : 2;
633 constexpr int num_simdgroups = 2;
634 constexpr int results_per_simdgroup = 4;
635 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
636 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
637 constexpr int values_per_thread = pack_factor * packs_per_thread;
638 constexpr int block_size = values_per_thread * SIMD_SIZE;
639 constexpr int scale_step_per_thread = group_size / values_per_thread;
640
641 const device uint8_t* ws = (const device uint8_t*)w;
642
643 typedef float U;
644
645 thread U x_thread[values_per_thread];
646 thread U result[results_per_simdgroup] = {0};
647
648 // Adjust positions
649 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
650 const int in_vec_size_g = in_vec_size / group_size;
651 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
652 simd_gid * results_per_simdgroup;
653
654 ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
655 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
656 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
657 x += tid.y * in_vec_size + simd_lid * values_per_thread;
658 y += tid.y * out_vec_size + out_row;
659
660 for (int k = 0; k < in_vec_size; k += block_size) {
661 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
662
663 for (int row = 0; row < results_per_simdgroup; row++) {
664 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
665 const device T* sl = scales + row * in_vec_size_g;
666 const device T* bl = biases + row * in_vec_size_g;
667
668 U s = sl[0];
669 U b = bl[0];
670 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
671 }
672
673 ws += block_size * bytes_per_pack / pack_factor;
674 scales += block_size / group_size;
675 biases += block_size / group_size;
676 x += block_size;
677 }
678
679 for (int row = 0; row < results_per_simdgroup; row++) {
680 result[row] = simd_sum(result[row]);
681 if (simd_lid == 0) {
682 y[row] = static_cast<T>(result[row]);
683 }
684 }
685}
686
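// qmv_impl: general quantized matrix-vector product with guarded loads for
// ragged sizes. If there is less than one full tile of output rows, every read
// is bounds-checked; otherwise the last tile is shifted back so a few rows are
// recomputed rather than read out of bounds.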
687template <typename T, int group_size, int bits>
688METAL_FUNC void qmv_impl(
689 const device uint32_t* w,
690 const device T* scales,
691 const device T* biases,
692 const device T* x,
693 device T* y,
694 const constant int& in_vec_size,
695 const constant int& out_vec_size,
696 uint3 tid [[threadgroup_position_in_grid]],
697 uint simd_gid [[simdgroup_index_in_threadgroup]],
698 uint simd_lid [[thread_index_in_simdgroup]]) {
699 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
700 constexpr int num_simdgroups = 2;
701 constexpr int results_per_simdgroup = 4;
702 constexpr int packs_per_thread = 1;
703 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
704 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
705 constexpr int values_per_thread = pack_factor * packs_per_thread;
706 constexpr int block_size = values_per_thread * SIMD_SIZE;
707 constexpr int scale_step_per_thread = group_size / values_per_thread;
708
709 const device uint8_t* ws = (const device uint8_t*)w;
710
711 typedef float U;
712
713 thread U x_thread[values_per_thread];
714 thread U result[results_per_simdgroup] = {0};
715
716 // Adjust positions
717 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
718 const int in_vec_size_g = in_vec_size / group_size;
719 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
720 simd_gid * results_per_simdgroup;
721 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
722
723 if (out_row >= out_vec_size) {
724 return;
725 }
726
727 // In this case we need to properly guard all our reads because there isn't
728 // even 1 tile in the matrix
729 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
730 ws +=
731 out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
732 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
733 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
734 x += tid.y * in_vec_size + simd_lid * values_per_thread;
735 y += tid.y * out_vec_size + out_row;
736
737 int k = 0;
738 for (; k < in_vec_size - block_size; k += block_size) {
739 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
740
741 for (int row = 0; out_row + row < out_vec_size; row++) {
742 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
743 const device T* sl = scales + row * in_vec_size_g;
744 const device T* bl = biases + row * in_vec_size_g;
745
746 U s = sl[0];
747 U b = bl[0];
748 result[row] +=
749 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
750 }
751
752 ws += block_size * bytes_per_pack / pack_factor;
753 scales += block_size / group_size;
754 biases += block_size / group_size;
755 x += block_size;
756 }
757 const int remaining = clamp(
758 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
759 0,
760 values_per_thread);
761 if (remaining > 0) {
762 U sum = load_vector_safe<T, U, values_per_thread, bits>(
763 x, x_thread, remaining);
764
765 for (int row = 0; out_row + row < out_vec_size; row++) {
766 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
767 const device T* sl = scales + row * in_vec_size_g;
768 const device T* bl = biases + row * in_vec_size_g;
769
770 U s = sl[0];
771 U b = bl[0];
772 result[row] +=
773 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
774 }
775 }
776
777 for (int row = 0; out_row + row < out_vec_size; row++) {
778 result[row] = simd_sum(result[row]);
779 if (simd_lid == 0) {
780 y[row] = static_cast<T>(result[row]);
781 }
782 }
783 }
784
785 // In this case the last tile is moved back to redo some output values
786 else {
787 ws += used_out_row * in_vec_size_w +
788 simd_lid * packs_per_thread * bytes_per_pack;
789 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
790 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
791 x += tid.y * in_vec_size + simd_lid * values_per_thread;
792 y += tid.y * out_vec_size + used_out_row;
793
794 int k = 0;
795 for (; k < in_vec_size - block_size; k += block_size) {
796 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
797
798 for (int row = 0; row < results_per_simdgroup; row++) {
799 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
800 const device T* sl = scales + row * in_vec_size_g;
801 const device T* bl = biases + row * in_vec_size_g;
802
803 U s = sl[0];
804 U b = bl[0];
805 result[row] +=
806 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
807 }
808
809 ws += block_size * bytes_per_pack / pack_factor;
810 scales += block_size / group_size;
811 biases += block_size / group_size;
812 x += block_size;
813 }
814 const int remaining = clamp(
815 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
816 0,
817 values_per_thread);
818 if (remaining > 0) {
819 U sum = load_vector_safe<T, U, values_per_thread, bits>(
820 x, x_thread, remaining);
821
822 for (int row = 0; row < results_per_simdgroup; row++) {
823 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
824 const device T* sl = scales + row * in_vec_size_g;
825 const device T* bl = biases + row * in_vec_size_g;
826
827 U s = sl[0];
828 U b = bl[0];
829 result[row] += qdot_safe<U, values_per_thread, bits>(
830 wl, x_thread, s, b, sum, remaining);
831 }
832 }
833 for (int row = 0; row < results_per_simdgroup; row++) {
834 result[row] = simd_sum(result[row]);
835 if (simd_lid == 0) {
836 y[row] = static_cast<T>(result[row]);
837 }
838 }
839 }
840}
841
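// qvm_impl: quantized vector-matrix product (x * W). Each simdgroup lane reads
// one input element per step and accumulates tn * pack_factor output columns
// via qouter; the partial sums are reduced with simd_sum and written by lane 0.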
842template <typename T, const int group_size, const int bits>
843METAL_FUNC void qvm_impl(
844 const device uint32_t* w,
845 const device T* scales,
846 const device T* biases,
847 const device T* x,
848 device T* y,
849 const int in_vec_size,
850 const int out_vec_size,
851 uint3 tid [[threadgroup_position_in_grid]],
852 uint simd_gid [[simdgroup_index_in_threadgroup]],
853 uint simd_lid [[thread_index_in_simdgroup]]) {
854 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
855 constexpr int num_simdgroups = 2;
856 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
857 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
858 constexpr int tn = 32 / pack_factor;
859 constexpr int block_size = SIMD_SIZE;
860
861 using W_T =
862 typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
863 const device W_T* ws = (const device W_T*)w;
864
865 typedef float U;
866 typedef struct {
867 W_T wi[tn * bytes_per_pack];
868 } vec_w;
869
870 thread vec_w w_local;
871 thread U result[tn * pack_factor] = {0};
872 thread U scale = 1;
873 thread U bias = 0;
874 thread U x_local = 0;
875
876 // Adjust positions
877 const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
878 const int out_vec_size_g = out_vec_size / group_size;
879 int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
880 ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
881 scales += out_col / group_size + simd_lid * out_vec_size_g;
882 biases += out_col / group_size + simd_lid * out_vec_size_g;
883 x += tid.y * in_vec_size + simd_lid;
884 y += tid.y * out_vec_size + out_col;
885
886 if (out_col >= out_vec_size) {
887 return;
888 }
889
890 // Loop over in_vec in blocks of block_size
891 int remaining = in_vec_size % block_size;
892 if (remaining == 0) {
893 for (int i = 0; i < in_vec_size; i += block_size) {
894 x_local = *x;
895 scale = *scales;
896 bias = *biases;
897 w_local = *((device vec_w*)ws);
898 qouter<U, tn * pack_factor, bits>(
899 (thread uint8_t*)&w_local, x_local, scale, bias, result);
900
901 x += block_size;
902 scales += block_size * out_vec_size_g;
903 biases += block_size * out_vec_size_g;
904 ws += block_size * out_vec_size_w;
905 }
906 } else {
907 for (int i = block_size; i < in_vec_size; i += block_size) {
908 x_local = *x;
909 scale = *scales;
910 bias = *biases;
911 w_local = *((device vec_w*)ws);
912
913 qouter<U, tn * pack_factor, bits>(
914 (thread uint8_t*)&w_local, x_local, scale, bias, result);
915
916 x += block_size;
917 scales += block_size * out_vec_size_g;
918 biases += block_size * out_vec_size_g;
919 ws += block_size * out_vec_size_w;
920 }
921 if (static_cast<int>(simd_lid) < remaining) {
922 x_local = *x;
923 scale = *scales;
924 bias = *biases;
925 w_local = *((device vec_w*)ws);
926 } else {
927 x_local = 0;
928 scale = 0;
929 bias = 0;
930 }
931 qouter<U, tn * pack_factor, bits>(
932 (thread uint8_t*)&w_local, x_local, scale, bias, result);
933 }
934
935// Accumulate in the simdgroup
936#pragma clang loop unroll(full)
937 for (int k = 0; k < tn * pack_factor; k++) {
938 result[k] = simd_sum(result[k]);
939 }
940
941 // Store the result
942 if (simd_lid == 0) {
943#pragma clang loop unroll(full)
944 for (int k = 0; k < tn * pack_factor; k++) {
945 y[k] = static_cast<T>(result[k]);
946 }
947 }
948}
949
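// qmm_t_impl: quantized matmul with the weight matrix transposed (y = x * W^T).
// X tiles and dequantized W tiles are staged in threadgroup memory and
// multiplied with a steel BlockMMA; aligned_N selects whether the N dimension
// needs guarded loads and stores.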
950template <
951 typename T,
952 const int group_size,
953 const int bits,
954 const bool aligned_N,
955 const int BM = 32,
956 const int BK = 32,
957 const int BN = 32>
958METAL_FUNC void qmm_t_impl(
959 const device uint32_t* w,
960 const device T* scales,
961 const device T* biases,
962 const device T* x,
963 device T* y,
964 threadgroup T* Xs,
965 threadgroup T* Ws,
966 const constant int& K,
967 const constant int& N,
968 const constant int& M,
969 uint3 tid [[threadgroup_position_in_grid]],
970 uint lid [[thread_index_in_threadgroup]],
971 uint simd_gid [[simdgroup_index_in_threadgroup]],
972 uint simd_lid [[thread_index_in_simdgroup]]) {
973 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
974 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
975
976 (void)lid;
977
978 constexpr int WM = 2;
979 constexpr int WN = 2;
980 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
981 constexpr int BK_padded = (BK + 16 / sizeof(T));
982 constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
983
984 // Instantiate the appropriate BlockMMA and Loader
985 using mma_t = mlx::steel::
986 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
988 using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
989 using loader_w_t = QuantizedBlockLoader<
990 T,
991 BN,
992 BK,
993 BK_padded,
994 1,
995 WM * WN * SIMD_SIZE,
996 group_size,
997 bits>;
998
999 // Set the block
1000 const int K_w = K * bytes_per_pack / pack_factor;
1001 const int K_g = K / group_size;
1002 const int y_row = tid.y * BM;
1003 const int y_col = tid.x * BN;
1004
1005 auto wl = (const device uint8_t*)w;
1006
1007 x += y_row * K;
1008 wl += y_col * K_w;
1009 scales += y_col * K_g;
1010 biases += y_col * K_g;
1011 y += y_row * N + y_col;
1012
1013 // Make the x loader and mma operation
1014 const short num_els = min(BM, M - y_row);
1015 const short num_outs = min(BN, N - y_col);
1016 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1017 loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
1018 mma_t mma_op(simd_gid, simd_lid);
1019
1020 if (num_els < BM) {
1021 if (!aligned_N && num_outs < BN) {
1022 for (int k = 0; k < K; k += BK) {
1023 threadgroup_barrier(mem_flags::mem_threadgroup);
1024 loader_x.load_safe(short2(BK, num_els));
1025 loader_w.load_safe(short2(BK, num_outs));
1026 threadgroup_barrier(mem_flags::mem_threadgroup);
1027 mma_op.mma(Xs, Ws);
1028 loader_x.next();
1029 loader_w.next();
1030 }
1031 } else {
1032 for (int k = 0; k < K; k += BK) {
1033 threadgroup_barrier(mem_flags::mem_threadgroup);
1034 loader_x.load_safe(short2(BK, num_els));
1035 loader_w.load_unsafe();
1036 threadgroup_barrier(mem_flags::mem_threadgroup);
1037 mma_op.mma(Xs, Ws);
1038 loader_x.next();
1039 loader_w.next();
1040 }
1041 }
1042 } else {
1043 if (!aligned_N && num_outs < BN) {
1044 for (int k = 0; k < K; k += BK) {
1045 threadgroup_barrier(mem_flags::mem_threadgroup);
1046 loader_x.load_unsafe();
1047 loader_w.load_safe(short2(BK, num_outs));
1048 threadgroup_barrier(mem_flags::mem_threadgroup);
1049 mma_op.mma(Xs, Ws);
1050 loader_x.next();
1051 loader_w.next();
1052 }
1053 } else {
1054 for (int k = 0; k < K; k += BK) {
1055 threadgroup_barrier(mem_flags::mem_threadgroup);
1056 loader_x.load_unsafe();
1057 loader_w.load_unsafe();
1058 threadgroup_barrier(mem_flags::mem_threadgroup);
1059
1060 mma_op.mma(Xs, Ws);
1061 loader_x.next();
1062 loader_w.next();
1063 }
1064 }
1065 }
1066
1067 // Store results to device memory
1068 threadgroup_barrier(mem_flags::mem_threadgroup);
1069 if (num_els < BM || num_outs < BN) {
1070 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
1071 } else {
1072 mma_op.store_result(y, N);
1073 }
1074}
1075
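// qmm_n_impl: quantized matmul with the weight matrix in its natural layout
// (y = x * W), with guarded loads when M or K is not a multiple of the block
// size.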
1076template <
1077 typename T,
1078 const int group_size,
1079 const int bits,
1080 const int BM = 32,
1081 const int BK = 32,
1082 const int BN = 32>
1083METAL_FUNC void qmm_n_impl(
1084 const device uint32_t* w,
1085 const device T* scales,
1086 const device T* biases,
1087 const device T* x,
1088 device T* y,
1089 threadgroup T* Xs,
1090 threadgroup T* Ws,
1091 const constant int& K,
1092 const constant int& N,
1093 const constant int& M,
1094 uint3 tid [[threadgroup_position_in_grid]],
1095 uint lid [[thread_index_in_threadgroup]],
1096 uint simd_gid [[simdgroup_index_in_threadgroup]],
1097 uint simd_lid [[thread_index_in_simdgroup]]) {
1098 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
1099 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
1100
1101 (void)lid;
1102
1103 constexpr int WM = 2;
1104 constexpr int WN = 2;
1105 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
1106 constexpr int BK_padded = (BK + 16 / sizeof(T));
1107 constexpr int BN_padded = (BN + 16 / sizeof(T));
1108 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
1109 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
1110
1111 // Instantiate the appropriate BlockMMA and Loader
1112 using mma_t = mlx::steel::
1113 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
1114 using loader_x_t = mlx::steel::
1115 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
1116 using loader_w_t = QuantizedBlockLoader<
1117 T,
1118 BK,
1119 BN,
1120 BN_padded,
1121 0,
1122 WM * WN * SIMD_SIZE,
1123 group_size,
1124 bits>;
1125
1126 auto wl = (const device uint8_t*)w;
1127
1128 // Set the block
1129 const int y_row = tid.y * BM;
1130 const int y_col = tid.x * BN;
1131 x += y_row * K;
1132 wl += y_col * bytes_per_pack / pack_factor;
1133 scales += y_col / group_size;
1134 biases += y_col / group_size;
1135 y += y_row * N + y_col;
1136
1137 // Make the x loader and mma operation
1138 const short num_els = min(BM, M - y_row);
1139 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1140 loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
1141 mma_t mma_op(simd_gid, simd_lid);
1142
1143 if (num_els < BM) {
1144 if ((K % BK) != 0) {
1145 const int k_blocks = K / BK;
1146 for (int k = 0; k < k_blocks; k++) {
1147 threadgroup_barrier(mem_flags::mem_threadgroup);
1148 loader_x.load_safe(short2(BK, num_els));
1149 loader_w.load_unsafe();
1150 threadgroup_barrier(mem_flags::mem_threadgroup);
1151 mma_op.mma(Xs, Ws);
1152 loader_x.next();
1153 loader_w.next();
1154 }
1155 const short num_k = K - k_blocks * BK;
1156 threadgroup_barrier(mem_flags::mem_threadgroup);
1157 loader_x.load_safe(short2(num_k, num_els));
1158 loader_w.load_safe(short2(BN, num_k));
1159 threadgroup_barrier(mem_flags::mem_threadgroup);
1160 mma_op.mma(Xs, Ws);
1161 } else {
1162 for (int k = 0; k < K; k += BK) {
1163 threadgroup_barrier(mem_flags::mem_threadgroup);
1164 loader_x.load_safe(short2(BK, num_els));
1165 loader_w.load_unsafe();
1166 threadgroup_barrier(mem_flags::mem_threadgroup);
1167 mma_op.mma(Xs, Ws);
1168 loader_x.next();
1169 loader_w.next();
1170 }
1171 }
1172 } else {
1173 if ((K % BK) != 0) {
1174 const int k_blocks = K / BK;
1175 for (int k = 0; k < k_blocks; k++) {
1176 threadgroup_barrier(mem_flags::mem_threadgroup);
1177 loader_x.load_unsafe();
1178 loader_w.load_unsafe();
1179 threadgroup_barrier(mem_flags::mem_threadgroup);
1180 mma_op.mma(Xs, Ws);
1181 loader_x.next();
1182 loader_w.next();
1183 }
1184 const short num_k = K - k_blocks * BK;
1185 threadgroup_barrier(mem_flags::mem_threadgroup);
1186 loader_x.load_safe(short2(num_k, BM));
1187 loader_w.load_safe(short2(BN, num_k));
1188 threadgroup_barrier(mem_flags::mem_threadgroup);
1189 mma_op.mma(Xs, Ws);
1190 } else {
1191 for (int k = 0; k < K; k += BK) {
1192 threadgroup_barrier(mem_flags::mem_threadgroup);
1193 loader_x.load_unsafe();
1194 loader_w.load_unsafe();
1195 threadgroup_barrier(mem_flags::mem_threadgroup);
1196 mma_op.mma(Xs, Ws);
1197 loader_x.next();
1198 loader_w.next();
1199 }
1200 }
1201 }
1202
1203 // Store results to device memory
1204 threadgroup_barrier(mem_flags::mem_threadgroup);
1205 if (num_els < BM) {
1206 mma_op.store_result_safe(y, N, short2(BN, num_els));
1207 } else {
1208 mma_op.store_result(y, N);
1209 }
1210}
1211
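// adjust_matrix_offsets advances the x, w, scales, biases and y pointers to
// the matrices selected by the batch index (tid.z), handling both contiguous
// and broadcast batch strides; the second overload additionally gathers the
// batch indices through lhs_indices / rhs_indices.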
1212template <typename T>
1213METAL_FUNC void adjust_matrix_offsets(
1214 const device T*& x,
1215 const device uint32_t*& w,
1216 const device T*& scales,
1217 const device T*& biases,
1218 device T*& y,
1219 int output_stride,
1220 const constant int& x_batch_ndims,
1221 const constant int* x_shape,
1222 const constant size_t* x_strides,
1223 const constant int& w_batch_ndims,
1224 const constant int* w_shape,
1225 const constant size_t* w_strides,
1226 const constant size_t* s_strides,
1227 const constant size_t* b_strides,
1228 uint3 tid [[threadgroup_position_in_grid]]) {
1229 // Set the input/output matrices
1230 uint32_t x_idx = tid.z;
1231 uint32_t w_idx = tid.z;
1232 if (x_batch_ndims == 1) {
1233 x += x_idx * x_strides[0];
1234 } else {
1235 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1236 }
1237 if (w_batch_ndims == 1) {
1238 w += w_idx * w_strides[0];
1239 scales += w_idx * s_strides[0];
1240 biases += w_idx * b_strides[0];
1241 } else {
1242 ulong3 idx = elem_to_loc_broadcast(
1243 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1244 w += idx.x;
1245 scales += idx.y;
1246 biases += idx.z;
1247 }
1248 y += tid.z * output_stride;
1249}
1250
1251template <typename T>
1252METAL_FUNC void adjust_matrix_offsets(
1253 const device T*& x,
1254 const device uint32_t*& w,
1255 const device T*& scales,
1256 const device T*& biases,
1257 const device uint32_t* lhs_indices,
1258 const device uint32_t* rhs_indices,
1259 device T*& y,
1260 int output_stride,
1261 const constant int& batch_ndims,
1262 const constant int* batch_shape,
1263 const constant size_t* lhs_strides,
1264 const constant size_t* rhs_strides,
1265 const constant int& x_batch_ndims,
1266 const constant int* x_shape,
1267 const constant size_t* x_strides,
1268 const constant int& w_batch_ndims,
1269 const constant int* w_shape,
1270 const constant size_t* w_strides,
1271 const constant size_t* s_strides,
1272 const constant size_t* b_strides,
1273 uint3 tid [[threadgroup_position_in_grid]]) {
1274 // Set the input/output matrices
1275 uint32_t x_idx;
1276 uint32_t w_idx;
1277 if (batch_ndims == 1) {
1278 x_idx = lhs_indices[tid.z * lhs_strides[0]];
1279 w_idx = rhs_indices[tid.z * rhs_strides[0]];
1280 } else {
1281 ulong2 idx = elem_to_loc_broadcast(
1282 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
1283 x_idx = lhs_indices[idx.x];
1284 w_idx = rhs_indices[idx.y];
1285 }
1286 if (x_batch_ndims == 1) {
1287 x += x_idx * x_strides[0];
1288 } else {
1289 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1290 }
1291 if (w_batch_ndims == 1) {
1292 w += w_idx * w_strides[0];
1293 scales += w_idx * s_strides[0];
1294 biases += w_idx * b_strides[0];
1295 } else {
1296 ulong3 idx = elem_to_loc_broadcast(
1297 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1298 w += idx.x;
1299 scales += idx.y;
1300 biases += idx.z;
1301 }
1302 y += tid.z * output_stride;
1303}
1304
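// The [[kernel]] entry points below are thin wrappers: they optionally apply
// the batch offsets via adjust_matrix_offsets and then dispatch to the
// corresponding *_impl routine. The bs_* variants gather per-batch operands
// through lhs_indices / rhs_indices.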
1305template <typename T, int group_size, int bits, int D, bool batched>
1306[[kernel]] void qmv_quad(
1307 const device uint32_t* w [[buffer(0)]],
1308 const device T* scales [[buffer(1)]],
1309 const device T* biases [[buffer(2)]],
1310 const device T* x [[buffer(3)]],
1311 device T* y [[buffer(4)]],
1312 const constant int& in_vec_size [[buffer(5)]],
1313 const constant int& out_vec_size [[buffer(6)]],
1314 const constant int& x_batch_ndims [[buffer(7)]],
1315 const constant int* x_shape [[buffer(8)]],
1316 const constant size_t* x_strides [[buffer(9)]],
1317 const constant int& w_batch_ndims [[buffer(10)]],
1318 const constant int* w_shape [[buffer(11)]],
1319 const constant size_t* w_strides [[buffer(12)]],
1320 const constant size_t* s_strides [[buffer(13)]],
1321 const constant size_t* b_strides [[buffer(14)]],
1322 uint3 tid [[threadgroup_position_in_grid]],
1323 uint quad_gid [[quadgroup_index_in_threadgroup]],
1324 uint quad_lid [[thread_index_in_quadgroup]]) {
1325 if (batched) {
1326 adjust_matrix_offsets<T>(
1327 x,
1328 w,
1329 scales,
1330 biases,
1331 y,
1332 out_vec_size,
1333 x_batch_ndims,
1334 x_shape,
1335 x_strides,
1336 w_batch_ndims,
1337 w_shape,
1338 w_strides,
1339 s_strides,
1340 b_strides,
1341 tid);
1342 }
1343 qmv_quad_impl<T, group_size, bits, D>(
1344 w,
1345 scales,
1346 biases,
1347 x,
1348 y,
1349 in_vec_size,
1350 out_vec_size,
1351 tid,
1352 quad_gid,
1353 quad_lid);
1354}
1355
1356template <typename T, int group_size, int bits, bool batched>
1357[[kernel]] void qmv_fast(
1358 const device uint32_t* w [[buffer(0)]],
1359 const device T* scales [[buffer(1)]],
1360 const device T* biases [[buffer(2)]],
1361 const device T* x [[buffer(3)]],
1362 device T* y [[buffer(4)]],
1363 const constant int& in_vec_size [[buffer(5)]],
1364 const constant int& out_vec_size [[buffer(6)]],
1365 const constant int& x_batch_ndims [[buffer(7)]],
1366 const constant int* x_shape [[buffer(8)]],
1367 const constant size_t* x_strides [[buffer(9)]],
1368 const constant int& w_batch_ndims [[buffer(10)]],
1369 const constant int* w_shape [[buffer(11)]],
1370 const constant size_t* w_strides [[buffer(12)]],
1371 const constant size_t* s_strides [[buffer(13)]],
1372 const constant size_t* b_strides [[buffer(14)]],
1373 uint3 tid [[threadgroup_position_in_grid]],
1374 uint simd_gid [[simdgroup_index_in_threadgroup]],
1375 uint simd_lid [[thread_index_in_simdgroup]]) {
1376 if (batched) {
1377 adjust_matrix_offsets<T>(
1378 x,
1379 w,
1380 scales,
1381 biases,
1382 y,
1383 out_vec_size,
1384 x_batch_ndims,
1385 x_shape,
1386 x_strides,
1387 w_batch_ndims,
1388 w_shape,
1389 w_strides,
1390 s_strides,
1391 b_strides,
1392 tid);
1393 }
1394 qmv_fast_impl<T, group_size, bits>(
1395 w,
1396 scales,
1397 biases,
1398 x,
1399 y,
1400 in_vec_size,
1401 out_vec_size,
1402 tid,
1403 simd_gid,
1404 simd_lid);
1405}
1406
1407template <typename T, const int group_size, const int bits, bool batched>
1408[[kernel]] void qmv(
1409 const device uint32_t* w [[buffer(0)]],
1410 const device T* scales [[buffer(1)]],
1411 const device T* biases [[buffer(2)]],
1412 const device T* x [[buffer(3)]],
1413 device T* y [[buffer(4)]],
1414 const constant int& in_vec_size [[buffer(5)]],
1415 const constant int& out_vec_size [[buffer(6)]],
1416 const constant int& x_batch_ndims [[buffer(7)]],
1417 const constant int* x_shape [[buffer(8)]],
1418 const constant size_t* x_strides [[buffer(9)]],
1419 const constant int& w_batch_ndims [[buffer(10)]],
1420 const constant int* w_shape [[buffer(11)]],
1421 const constant size_t* w_strides [[buffer(12)]],
1422 const constant size_t* s_strides [[buffer(13)]],
1423 const constant size_t* b_strides [[buffer(14)]],
1424 uint3 tid [[threadgroup_position_in_grid]],
1425 uint simd_gid [[simdgroup_index_in_threadgroup]],
1426 uint simd_lid [[thread_index_in_simdgroup]]) {
1427 if (batched) {
1428 adjust_matrix_offsets<T>(
1429 x,
1430 w,
1431 scales,
1432 biases,
1433 y,
1434 out_vec_size,
1435 x_batch_ndims,
1436 x_shape,
1437 x_strides,
1438 w_batch_ndims,
1439 w_shape,
1440 w_strides,
1441 s_strides,
1442 b_strides,
1443 tid);
1444 }
1445 qmv_impl<T, group_size, bits>(
1446 w,
1447 scales,
1448 biases,
1449 x,
1450 y,
1451 in_vec_size,
1452 out_vec_size,
1453 tid,
1454 simd_gid,
1455 simd_lid);
1456}
1457
1458template <typename T, const int group_size, const int bits, bool batched>
1459[[kernel]] void qvm(
1460 const device uint32_t* w [[buffer(0)]],
1461 const device T* scales [[buffer(1)]],
1462 const device T* biases [[buffer(2)]],
1463 const device T* x [[buffer(3)]],
1464 device T* y [[buffer(4)]],
1465 const constant int& in_vec_size [[buffer(5)]],
1466 const constant int& out_vec_size [[buffer(6)]],
1467 const constant int& x_batch_ndims [[buffer(7)]],
1468 const constant int* x_shape [[buffer(8)]],
1469 const constant size_t* x_strides [[buffer(9)]],
1470 const constant int& w_batch_ndims [[buffer(10)]],
1471 const constant int* w_shape [[buffer(11)]],
1472 const constant size_t* w_strides [[buffer(12)]],
1473 const constant size_t* s_strides [[buffer(13)]],
1474 const constant size_t* b_strides [[buffer(14)]],
1475 uint3 tid [[threadgroup_position_in_grid]],
1476 uint simd_gid [[simdgroup_index_in_threadgroup]],
1477 uint simd_lid [[thread_index_in_simdgroup]]) {
1478 if (batched) {
1479 adjust_matrix_offsets<T>(
1480 x,
1481 w,
1482 scales,
1483 biases,
1484 y,
1485 out_vec_size,
1486 x_batch_ndims,
1487 x_shape,
1488 x_strides,
1489 w_batch_ndims,
1490 w_shape,
1491 w_strides,
1492 s_strides,
1493 b_strides,
1494 tid);
1495 }
1496 qvm_impl<T, group_size, bits>(
1497 w,
1498 scales,
1499 biases,
1500 x,
1501 y,
1502 in_vec_size,
1503 out_vec_size,
1504 tid,
1505 simd_gid,
1506 simd_lid);
1507}
1508
1509template <typename T, const int group_size, const int bits, int split_k = 32>
1510[[kernel]] void qvm_split_k(
1511 const device uint32_t* w [[buffer(0)]],
1512 const device T* scales [[buffer(1)]],
1513 const device T* biases [[buffer(2)]],
1514 const device T* x [[buffer(3)]],
1515 device T* y [[buffer(4)]],
1516 const constant int& in_vec_size [[buffer(5)]],
1517 const constant int& out_vec_size [[buffer(6)]],
1518 const constant int& x_batch_ndims [[buffer(7)]],
1519 const constant int* x_shape [[buffer(8)]],
1520 const constant size_t* x_strides [[buffer(9)]],
1521 const constant int& w_batch_ndims [[buffer(10)]],
1522 const constant int* w_shape [[buffer(11)]],
1523 const constant size_t* w_strides [[buffer(12)]],
1524 const constant size_t* s_strides [[buffer(13)]],
1525 const constant size_t* b_strides [[buffer(14)]],
1526 const constant int& final_block_size [[buffer(15)]],
1527 uint3 tid [[threadgroup_position_in_grid]],
1528 uint simd_gid [[simdgroup_index_in_threadgroup]],
1529 uint simd_lid [[thread_index_in_simdgroup]]) {
1530 adjust_matrix_offsets<T>(
1531 x,
1532 w,
1533 scales,
1534 biases,
1535 y,
1536 out_vec_size,
1537 x_batch_ndims,
1538 x_shape,
1539 x_strides,
1540 w_batch_ndims,
1541 w_shape,
1542 w_strides,
1543 s_strides,
1544 b_strides,
1545 tid);
1546
1547 // When (in_vec_size % split_k != 0) the final block needs to be smaller
1548 int in_vec_size_adj =
1549 tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
1550
1551 qvm_impl<T, group_size, bits>(
1552 w,
1553 scales,
1554 biases,
1555 x,
1556 y,
1557 in_vec_size_adj,
1558 out_vec_size,
1559 tid,
1560 simd_gid,
1561 simd_lid);
1562}
1563
1564template <
1565 typename T,
1566 const int group_size,
1567 const int bits,
1568 const bool aligned_N,
1569 const bool batched,
1570 const int BM = 32,
1571 const int BK = 32,
1572 const int BN = 32>
1573[[kernel]] void qmm_t(
1574 const device uint32_t* w [[buffer(0)]],
1575 const device T* scales [[buffer(1)]],
1576 const device T* biases [[buffer(2)]],
1577 const device T* x [[buffer(3)]],
1578 device T* y [[buffer(4)]],
1579 const constant int& K [[buffer(5)]],
1580 const constant int& N [[buffer(6)]],
1581 const constant int& M [[buffer(7)]],
1582 const constant int& x_batch_ndims [[buffer(8)]],
1583 const constant int* x_shape [[buffer(9)]],
1584 const constant size_t* x_strides [[buffer(10)]],
1585 const constant int& w_batch_ndims [[buffer(11)]],
1586 const constant int* w_shape [[buffer(12)]],
1587 const constant size_t* w_strides [[buffer(13)]],
1588 const constant size_t* s_strides [[buffer(14)]],
1589 const constant size_t* b_strides [[buffer(15)]],
1590 uint3 tid [[threadgroup_position_in_grid]],
1591 uint lid [[thread_index_in_threadgroup]],
1592 uint simd_gid [[simdgroup_index_in_threadgroup]],
1593 uint simd_lid [[thread_index_in_simdgroup]]) {
1594 (void)lid;
1595
1596 constexpr int BK_padded = (BK + 16 / sizeof(T));
1597
1598 threadgroup T Xs[BM * BK_padded];
1599 threadgroup T Ws[BN * BK_padded];
1600
1601 if (batched) {
1602 adjust_matrix_offsets<T>(
1603 x,
1604 w,
1605 scales,
1606 biases,
1607 y,
1608 M * N,
1609 x_batch_ndims,
1610 x_shape,
1611 x_strides,
1612 w_batch_ndims,
1613 w_shape,
1614 w_strides,
1615 s_strides,
1616 b_strides,
1617 tid);
1618 }
1619 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1620 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1621}
1622
1623template <
1624 typename T,
1625 const int group_size,
1626 const int bits,
1627 const bool batched,
1628 const int BM = 32,
1629 const int BK = 32,
1630 const int BN = 32>
1631[[kernel]] void qmm_n(
1632 const device uint32_t* w [[buffer(0)]],
1633 const device T* scales [[buffer(1)]],
1634 const device T* biases [[buffer(2)]],
1635 const device T* x [[buffer(3)]],
1636 device T* y [[buffer(4)]],
1637 const constant int& K [[buffer(5)]],
1638 const constant int& N [[buffer(6)]],
1639 const constant int& M [[buffer(7)]],
1640 const constant int& x_batch_ndims [[buffer(8)]],
1641 const constant int* x_shape [[buffer(9)]],
1642 const constant size_t* x_strides [[buffer(10)]],
1643 const constant int& w_batch_ndims [[buffer(11)]],
1644 const constant int* w_shape [[buffer(12)]],
1645 const constant size_t* w_strides [[buffer(13)]],
1646 const constant size_t* s_strides [[buffer(14)]],
1647 const constant size_t* b_strides [[buffer(15)]],
1648 uint3 tid [[threadgroup_position_in_grid]],
1649 uint lid [[thread_index_in_threadgroup]],
1650 uint simd_gid [[simdgroup_index_in_threadgroup]],
1651 uint simd_lid [[thread_index_in_simdgroup]]) {
1652 (void)lid;
1653
1654 constexpr int BK_padded = (BK + 16 / sizeof(T));
1655 constexpr int BN_padded = (BN + 16 / sizeof(T));
1656
1657 threadgroup T Xs[BM * BK_padded];
1658 threadgroup T Ws[BK * BN_padded];
1659
1660 if (batched) {
1661 adjust_matrix_offsets<T>(
1662 x,
1663 w,
1664 scales,
1665 biases,
1666 y,
1667 M * N,
1668 x_batch_ndims,
1669 x_shape,
1670 x_strides,
1671 w_batch_ndims,
1672 w_shape,
1673 w_strides,
1674 s_strides,
1675 b_strides,
1676 tid);
1677 }
1678
1679 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1680 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1681}
1682
1683template <typename T, int group_size, int bits>
1684[[kernel]] void bs_qmv_fast(
1685 const device uint32_t* w [[buffer(0)]],
1686 const device T* scales [[buffer(1)]],
1687 const device T* biases [[buffer(2)]],
1688 const device T* x [[buffer(3)]],
1689 device T* y [[buffer(4)]],
1690 const constant int& in_vec_size [[buffer(5)]],
1691 const constant int& out_vec_size [[buffer(6)]],
1692 const constant int& x_batch_ndims [[buffer(7)]],
1693 const constant int* x_shape [[buffer(8)]],
1694 const constant size_t* x_strides [[buffer(9)]],
1695 const constant int& w_batch_ndims [[buffer(10)]],
1696 const constant int* w_shape [[buffer(11)]],
1697 const constant size_t* w_strides [[buffer(12)]],
1698 const constant size_t* s_strides [[buffer(13)]],
1699 const constant size_t* b_strides [[buffer(14)]],
1700 const constant int& batch_ndims [[buffer(15)]],
1701 const constant int* batch_shape [[buffer(16)]],
1702 const device uint32_t* lhs_indices [[buffer(17)]],
1703 const device uint32_t* rhs_indices [[buffer(18)]],
1704 const constant size_t* lhs_strides [[buffer(19)]],
1705 const constant size_t* rhs_strides [[buffer(20)]],
1706 uint3 tid [[threadgroup_position_in_grid]],
1707 uint simd_gid [[simdgroup_index_in_threadgroup]],
1708 uint simd_lid [[thread_index_in_simdgroup]]) {
1709 adjust_matrix_offsets<T>(
1710 x,
1711 w,
1712 scales,
1713 biases,
1714 lhs_indices,
1715 rhs_indices,
1716 y,
1717 out_vec_size,
1718 batch_ndims,
1719 batch_shape,
1720 lhs_strides,
1721 rhs_strides,
1722 x_batch_ndims,
1723 x_shape,
1724 x_strides,
1725 w_batch_ndims,
1726 w_shape,
1727 w_strides,
1728 s_strides,
1729 b_strides,
1730 tid);
1731 qmv_fast_impl<T, group_size, bits>(
1732 w,
1733 scales,
1734 biases,
1735 x,
1736 y,
1737 in_vec_size,
1738 out_vec_size,
1739 tid,
1740 simd_gid,
1741 simd_lid);
1742}
1743
1744template <typename T, int group_size, int bits>
1745[[kernel]] void bs_qmv(
1746 const device uint32_t* w [[buffer(0)]],
1747 const device T* scales [[buffer(1)]],
1748 const device T* biases [[buffer(2)]],
1749 const device T* x [[buffer(3)]],
1750 device T* y [[buffer(4)]],
1751 const constant int& in_vec_size [[buffer(5)]],
1752 const constant int& out_vec_size [[buffer(6)]],
1753 const constant int& x_batch_ndims [[buffer(7)]],
1754 const constant int* x_shape [[buffer(8)]],
1755 const constant size_t* x_strides [[buffer(9)]],
1756 const constant int& w_batch_ndims [[buffer(10)]],
1757 const constant int* w_shape [[buffer(11)]],
1758 const constant size_t* w_strides [[buffer(12)]],
1759 const constant size_t* s_strides [[buffer(13)]],
1760 const constant size_t* b_strides [[buffer(14)]],
1761 const constant int& batch_ndims [[buffer(15)]],
1762 const constant int* batch_shape [[buffer(16)]],
1763 const device uint32_t* lhs_indices [[buffer(17)]],
1764 const device uint32_t* rhs_indices [[buffer(18)]],
1765 const constant size_t* lhs_strides [[buffer(19)]],
1766 const constant size_t* rhs_strides [[buffer(20)]],
1767 uint3 tid [[threadgroup_position_in_grid]],
1768 uint simd_gid [[simdgroup_index_in_threadgroup]],
1769 uint simd_lid [[thread_index_in_simdgroup]]) {
1770 adjust_matrix_offsets<T>(
1771 x,
1772 w,
1773 scales,
1774 biases,
1775 lhs_indices,
1776 rhs_indices,
1777 y,
1778 out_vec_size,
1779 batch_ndims,
1780 batch_shape,
1781 lhs_strides,
1782 rhs_strides,
1783 x_batch_ndims,
1784 x_shape,
1785 x_strides,
1786 w_batch_ndims,
1787 w_shape,
1788 w_strides,
1789 s_strides,
1790 b_strides,
1791 tid);
1792 qmv_impl<T, group_size, bits>(
1793 w,
1794 scales,
1795 biases,
1796 x,
1797 y,
1798 in_vec_size,
1799 out_vec_size,
1800 tid,
1801 simd_gid,
1802 simd_lid);
1803}
1804
1805template <typename T, int group_size, int bits>
1806[[kernel]] void bs_qvm(
1807 const device uint32_t* w [[buffer(0)]],
1808 const device T* scales [[buffer(1)]],
1809 const device T* biases [[buffer(2)]],
1810 const device T* x [[buffer(3)]],
1811 device T* y [[buffer(4)]],
1812 const constant int& in_vec_size [[buffer(5)]],
1813 const constant int& out_vec_size [[buffer(6)]],
1814 const constant int& x_batch_ndims [[buffer(7)]],
1815 const constant int* x_shape [[buffer(8)]],
1816 const constant size_t* x_strides [[buffer(9)]],
1817 const constant int& w_batch_ndims [[buffer(10)]],
1818 const constant int* w_shape [[buffer(11)]],
1819 const constant size_t* w_strides [[buffer(12)]],
1820 const constant size_t* s_strides [[buffer(13)]],
1821 const constant size_t* b_strides [[buffer(14)]],
1822 const constant int& batch_ndims [[buffer(15)]],
1823 const constant int* batch_shape [[buffer(16)]],
1824 const device uint32_t* lhs_indices [[buffer(17)]],
1825 const device uint32_t* rhs_indices [[buffer(18)]],
1826 const constant size_t* lhs_strides [[buffer(19)]],
1827 const constant size_t* rhs_strides [[buffer(20)]],
1828 uint3 tid [[threadgroup_position_in_grid]],
1829 uint simd_gid [[simdgroup_index_in_threadgroup]],
1830 uint simd_lid [[thread_index_in_simdgroup]]) {
1831 adjust_matrix_offsets<T>(
1832 x,
1833 w,
1834 scales,
1835 biases,
1836 lhs_indices,
1837 rhs_indices,
1838 y,
1839 out_vec_size,
1840 batch_ndims,
1841 batch_shape,
1842 lhs_strides,
1843 rhs_strides,
1844 x_batch_ndims,
1845 x_shape,
1846 x_strides,
1847 w_batch_ndims,
1848 w_shape,
1849 w_strides,
1850 s_strides,
1851 b_strides,
1852 tid);
1853 qvm_impl<T, group_size, bits>(
1854 w,
1855 scales,
1856 biases,
1857 x,
1858 y,
1859 in_vec_size,
1860 out_vec_size,
1861 tid,
1862 simd_gid,
1863 simd_lid);
1864}
1865
1866template <
1867 typename T,
1868 const int group_size,
1869 const int bits,
1870 const bool aligned_N,
1871 const int BM = 32,
1872 const int BK = 32,
1873 const int BN = 32>
1874[[kernel]] void bs_qmm_t(
1875 const device uint32_t* w [[buffer(0)]],
1876 const device T* scales [[buffer(1)]],
1877 const device T* biases [[buffer(2)]],
1878 const device T* x [[buffer(3)]],
1879 device T* y [[buffer(4)]],
1880 const constant int& K [[buffer(5)]],
1881 const constant int& N [[buffer(6)]],
1882 const constant int& M [[buffer(7)]],
1883 const constant int& x_batch_ndims [[buffer(8)]],
1884 const constant int* x_shape [[buffer(9)]],
1885 const constant size_t* x_strides [[buffer(10)]],
1886 const constant int& w_batch_ndims [[buffer(11)]],
1887 const constant int* w_shape [[buffer(12)]],
1888 const constant size_t* w_strides [[buffer(13)]],
1889 const constant size_t* s_strides [[buffer(14)]],
1890 const constant size_t* b_strides [[buffer(15)]],
1891 const constant int& batch_ndims [[buffer(16)]],
1892 const constant int* batch_shape [[buffer(17)]],
1893 const device uint32_t* lhs_indices [[buffer(18)]],
1894 const device uint32_t* rhs_indices [[buffer(19)]],
1895 const constant size_t* lhs_strides [[buffer(20)]],
1896 const constant size_t* rhs_strides [[buffer(21)]],
1897 uint3 tid [[threadgroup_position_in_grid]],
1898 uint lid [[thread_index_in_threadgroup]],
1899 uint simd_gid [[simdgroup_index_in_threadgroup]],
1900 uint simd_lid [[thread_index_in_simdgroup]]) {
1901 (void)lid;
1902
1903 constexpr int BK_padded = (BK + 16 / sizeof(T));
1904
1905 threadgroup T Xs[BM * BK_padded];
1906 threadgroup T Ws[BN * BK_padded];
1907
1908 adjust_matrix_offsets<T>(
1909 x,
1910 w,
1911 scales,
1912 biases,
1913 lhs_indices,
1914 rhs_indices,
1915 y,
1916 M * N,
1917 batch_ndims,
1918 batch_shape,
1919 lhs_strides,
1920 rhs_strides,
1921 x_batch_ndims,
1922 x_shape,
1923 x_strides,
1924 w_batch_ndims,
1925 w_shape,
1926 w_strides,
1927 s_strides,
1928 b_strides,
1929 tid);
1930  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1931 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1932}
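// aligned_N appears to indicate that N is a multiple of the BN tile size, so
// the transposed matmul can skip bounds checks on the last column tile; this
// is inferred from the parameter name rather than documented here.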
1933
1934template <
1935 typename T,
1936 const int group_size,
1937 const int bits,
1938 const int BM = 32,
1939 const int BK = 32,
1940 const int BN = 32>
1941[[kernel]] void bs_qmm_n(
1942 const device uint32_t* w [[buffer(0)]],
1943 const device T* scales [[buffer(1)]],
1944 const device T* biases [[buffer(2)]],
1945 const device T* x [[buffer(3)]],
1946 device T* y [[buffer(4)]],
1947 const constant int& K [[buffer(5)]],
1948 const constant int& N [[buffer(6)]],
1949 const constant int& M [[buffer(7)]],
1950 const constant int& x_batch_ndims [[buffer(8)]],
1951 const constant int* x_shape [[buffer(9)]],
1952 const constant size_t* x_strides [[buffer(10)]],
1953 const constant int& w_batch_ndims [[buffer(11)]],
1954 const constant int* w_shape [[buffer(12)]],
1955 const constant size_t* w_strides [[buffer(13)]],
1956 const constant size_t* s_strides [[buffer(14)]],
1957 const constant size_t* b_strides [[buffer(15)]],
1958 const constant int& batch_ndims [[buffer(16)]],
1959 const constant int* batch_shape [[buffer(17)]],
1960 const device uint32_t* lhs_indices [[buffer(18)]],
1961 const device uint32_t* rhs_indices [[buffer(19)]],
1962 const constant size_t* lhs_strides [[buffer(20)]],
1963 const constant size_t* rhs_strides [[buffer(21)]],
1964 uint3 tid [[threadgroup_position_in_grid]],
1965 uint lid [[thread_index_in_threadgroup]],
1966 uint simd_gid [[simdgroup_index_in_threadgroup]],
1967 uint simd_lid [[thread_index_in_simdgroup]]) {
1968 (void)lid;
1969
1970 constexpr int BK_padded = (BK + 16 / sizeof(T));
1971 constexpr int BN_padded = (BN + 16 / sizeof(T));
1972
1973 threadgroup T Xs[BM * BK_padded];
1974 threadgroup T Ws[BK * BN_padded];
1975
1976  adjust_matrix_offsets<T>(
1977 x,
1978 w,
1979 scales,
1980 biases,
1981 lhs_indices,
1982 rhs_indices,
1983 y,
1984 M * N,
1985 batch_ndims,
1986 batch_shape,
1987 lhs_strides,
1988 rhs_strides,
1989 x_batch_ndims,
1990 x_shape,
1991 x_strides,
1992 w_batch_ndims,
1993 w_shape,
1994 w_strides,
1995 s_strides,
1996 b_strides,
1997 tid);
1998  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1999 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
2000}
2001
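// affine_quantize maps each group of group_size consecutive weights to
// bits-bit integers with a per-group affine transform:
//   q = min(round((w - bias) / scale), 2^bits - 1)
// and affine_dequantize (below) recovers w ≈ q * scale + bias. A simdgroup of
// 32 threads covers one group: each thread reduces
// values_per_reduce = group_size / simd_size elements, and the group's scale
// and bias are derived from a simd_min / simd_max over those threads.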
2002template <typename T, const int group_size, const int bits>
2003[[kernel]] void affine_quantize(
2004 const device T* w [[buffer(0)]],
2005 device uint8_t* out [[buffer(1)]],
2006 device T* scales [[buffer(2)]],
2007 device T* biases [[buffer(3)]],
2008 uint2 index [[thread_position_in_grid]],
2009 uint2 grid_dim [[threads_per_grid]]) {
2010 constexpr T eps = T(1e-7);
2011 constexpr int simd_size = 32;
2012 constexpr T n_bins = (1 << bits) - 1;
2013 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2014 constexpr int values_per_reduce = group_size / simd_size;
2015 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
2016 constexpr int writes_per_pack =
2017 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
2018 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2019 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2020
2021 static_assert(
2022 group_size % simd_size == 0,
2023 "Group size must be divisible by simd size.");
2024
2025 size_t offset = index.x + grid_dim.x * size_t(index.y);
2026 size_t in_index = offset * values_per_reduce;
2027 size_t out_index = power_of_2_bits
2028 ? offset * writes_per_pack
2029 : offset * bytes_per_pack / writes_per_reduce;
2030
2031 T w_thread[values_per_reduce];
2032 T w_min = Limits<T>::max;
2033 T w_max = 0;
2034
2035#pragma clang loop unroll(full)
2036 for (int i = 0; i < values_per_reduce; i++) {
2037 T val = w[in_index + i];
2038 w_thread[i] = val;
2039 w_min = min(w_min, val);
2040 w_max = max(w_max, val);
2041 }
2042
2043 w_min = simd_min(w_min);
2044 w_max = simd_max(w_max);
2045
2046 T scale = max((w_max - w_min) / n_bins, eps);
2047 bool side = abs(w_min) > abs(w_max);
2048 scale = side ? scale : -scale;
2049 T edge = side ? w_min : w_max;
2050 T q0 = round(edge / scale);
2051 bool at_zero = q0 == 0.0f;
2052 scale = at_zero ? scale : edge / q0;
2053 T bias = at_zero ? T(0) : edge;
2054
2055 // Write out the scales and biases
2056 size_t gindex = in_index / group_size;
2057 if (in_index % group_size == 0) {
2058 scales[gindex] = scale;
2059 biases[gindex] = bias;
2060 }
2061
2062  // For 3- and 6-bit quantization we accumulate 3 bytes' worth of packed values, so we need a uint32_t
2063 uint32_t output = 0;
2064
2065#pragma clang loop unroll(full)
2066 for (int i = 0; i < values_per_reduce; i++) {
2067 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
2068 if (bits == 8) {
2069 output = val;
2070 } else {
2071 output += val << (bits * (i % packs_per_int));
2072 }
2073
2074 if (packs_per_int < values_per_reduce &&
2075 i % packs_per_int == packs_per_int - 1) {
2076 out[out_index + i / packs_per_int] = output;
2077 output = 0;
2078 } else {
2079#pragma clang loop unroll(full)
2080 for (int j = 1; j < writes_per_reduce; j++) {
2081 uint8_t sval = simd_shuffle_down(val, j);
2082 output += sval << (bits * (j * values_per_reduce + i));
2083 }
2084 }
2085 }
2086 if (bits == 3 || bits == 6) {
2087 if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
2088 out[out_index] = output & 0xff;
2089 out[out_index + 1] = (output & 0xff00) >> 8;
2090 out[out_index + 2] = (output & 0xff0000) >> 16;
2091 }
2092 } else {
2093 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
2094 out[out_index / writes_per_reduce] = output;
2095 }
2096 }
2097}
2098
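// affine_dequantize reverses the packing above: for power-of-two bit widths
// (2, 4, 8) each thread unpacks one byte into packs_per_int values, while
// 3-bit unpacks eight values and 6-bit unpacks four values from a 3-byte
// pack; every output is reconstructed as q * scale + bias with its group's
// scale and bias.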
2099template <typename T, const int group_size, const int bits>
2100[[kernel]] void affine_dequantize(
2101 const device uint8_t* w [[buffer(0)]],
2102 const device T* scales [[buffer(1)]],
2103 const device T* biases [[buffer(2)]],
2104 device T* out [[buffer(3)]],
2105 uint2 index [[thread_position_in_grid]],
2106 uint2 grid_dim [[threads_per_grid]]) {
2107 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2108 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2109 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2110
2111 size_t offset = index.x + grid_dim.x * size_t(index.y);
2112 size_t oindex = offset * packs_per_int;
2113 size_t gindex = oindex / group_size;
2114 T scale = scales[gindex];
2115 T bias = biases[gindex];
2116
2117 out += oindex;
2118
2119 if (bits == 3) {
2120 w += offset * bytes_per_pack;
2121 out[0] = (w[0] & 0x7) * scale + bias;
2122 out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
2123 out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
2124 out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
2125 out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
2126 out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
2127 out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
2128 out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
2129
2130 } else if (bits == 6) {
2131 w += offset * bytes_per_pack;
2132 out[0] = (w[0] & 0x3f) * scale + bias;
2133 out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
2134 out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
2135 out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
2136 } else {
2137 uint val = w[offset];
2138#pragma clang loop unroll(full)
2139 for (int i = 0; i < packs_per_int; i++) {
2140 uint8_t d;
2141 if (bits == 2) {
2142 d = (val >> (bits * i)) & 0x03;
2143 } else if (bits == 4) {
2144 d = (val >> (bits * i)) & 0x0f;
2145 } else if (bits == 8) {
2146 d = val;
2147 }
2148 out[i] = scale * d + bias;
2149 }
2150 }
2151}