quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10MLX_MTL_CONST int SIMD_SIZE = 32;
11MLX_MTL_CONST int QUAD_SIZE = 4;
12
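// load_vector copies values_per_thread inputs from x into registers. Each lane
// is pre-divided by the power-of-two position of its packed weight field so
// that qdot below can multiply by masked (but unshifted) weight bits. The
// returned sum of the raw x values feeds the bias term of the affine dot
// product.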
13template <typename T, typename U, int values_per_thread, int bits>
14inline U load_vector(const device T* x, thread U* x_thread) {
15 static_assert(
16 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
17 "Template undefined for bits not in {2, 3, 4, 6, 8}");
18
19 U sum = 0;
20
21 if (bits == 2) {
22 for (int i = 0; i < values_per_thread; i += 4) {
23 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
24 x_thread[i] = x[i];
25 x_thread[i + 1] = x[i + 1] / 4.0f;
26 x_thread[i + 2] = x[i + 2] / 16.0f;
27 x_thread[i + 3] = x[i + 3] / 64.0f;
28 }
29 }
30
31 else if (bits == 3) {
32 for (int i = 0; i < values_per_thread; i += 8) {
33 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
34 x[i + 6] + x[i + 7];
35 x_thread[i] = x[i];
36 x_thread[i + 1] = x[i + 1] / 8.0f;
37 x_thread[i + 2] = x[i + 2] / 64.0f;
38 x_thread[i + 3] = x[i + 3] / 2.0f;
39 x_thread[i + 4] = x[i + 4] / 16.0f;
40 x_thread[i + 5] = x[i + 5] / 128.0f;
41 x_thread[i + 6] = x[i + 6] / 4.0f;
42 x_thread[i + 7] = x[i + 7] / 32.0f;
43 }
44 }
45
46 else if (bits == 4) {
47 for (int i = 0; i < values_per_thread; i += 4) {
48 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
49 x_thread[i] = x[i];
50 x_thread[i + 1] = x[i + 1] / 16.0f;
51 x_thread[i + 2] = x[i + 2] / 256.0f;
52 x_thread[i + 3] = x[i + 3] / 4096.0f;
53 }
54 }
55
56 else if (bits == 6) {
57 for (int i = 0; i < values_per_thread; i += 4) {
58 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
59 x_thread[i] = x[i];
60 x_thread[i + 1] = x[i + 1] / 64.0f;
61 x_thread[i + 2] = x[i + 2] / 16.0f;
62 x_thread[i + 3] = x[i + 3] / 4.0f;
63 }
64 }
65
66 else if (bits == 8) {
67 for (int i = 0; i < values_per_thread; i++) {
68 sum += x[i];
69 x_thread[i] = x[i];
70 }
71 }
72
73 return sum;
74}
75
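// Same as load_vector, but reads only the first N valid inputs and zero-fills
// the remaining register slots.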
76template <typename T, typename U, int values_per_thread, int bits>
77inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
78 static_assert(
79 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
80 "Template undefined for bits not in {2, 3, 4, 6, 8}");
81
82 U sum = 0;
83
84 if (bits == 2) {
85 for (int i = 0; i < N; i += 4) {
86 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
87 x_thread[i] = x[i];
88 x_thread[i + 1] = x[i + 1] / 4.0f;
89 x_thread[i + 2] = x[i + 2] / 16.0f;
90 x_thread[i + 3] = x[i + 3] / 64.0f;
91 }
92 }
93
94 else if (bits == 3) {
95 for (int i = 0; i < N; i += 8) {
96 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
97 x[i + 6] + x[i + 7];
98
99 x_thread[i] = x[i];
100 x_thread[i + 1] = x[i + 1] / 8.0f;
101 x_thread[i + 2] = x[i + 2] / 64.0f;
102 x_thread[i + 3] = x[i + 3] / 2.0f;
103 x_thread[i + 4] = x[i + 4] / 16.0f;
104 x_thread[i + 5] = x[i + 5] / 128.0f;
105 x_thread[i + 6] = x[i + 6] / 4.0f;
106 x_thread[i + 7] = x[i + 7] / 32.0f;
107 }
108 }
109
110 else if (bits == 4) {
111 for (int i = 0; i < N; i += 4) {
112 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
113 x_thread[i] = x[i];
114 x_thread[i + 1] = x[i + 1] / 16.0f;
115 x_thread[i + 2] = x[i + 2] / 256.0f;
116 x_thread[i + 3] = x[i + 3] / 4096.0f;
117 }
118 }
119
120 else if (bits == 6) {
121 for (int i = 0; i < N; i += 4) {
122 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
123 x_thread[i] = x[i];
124 x_thread[i + 1] = x[i + 1] / 64.0f;
125 x_thread[i + 2] = x[i + 2] / 16.0f;
126 x_thread[i + 3] = x[i + 3] / 4.0f;
127 }
128 }
129
130 else if (bits == 8) {
131 for (int i = 0; i < N; i++) {
132 sum += x[i];
133 x_thread[i] = x[i];
134 }
135 }
136
137 for (int i = N; i < values_per_thread; i++) {
138 x_thread[i] = 0;
139 }
140
141 return sum;
142}
143
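// qdot accumulates one thread's partial dot product between packed weights and
// the pre-scaled x_thread values. The weight fields are masked but not
// shifted; the shift was folded into x_thread by load_vector. For example,
// with 4-bit packing a byte holds q0 + 16 * q1, so
// (w & 0x0f) * x0 + (w & 0xf0) * (x1 / 16) equals q0 * x0 + q1 * x1. With
// dequantized weights scale * q + bias, the dot product is
// scale * accum + bias * sum, which is what is returned.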
144template <typename U, int values_per_thread, int bits>
145inline U qdot(
146 const device uint8_t* w,
147 const thread U* x_thread,
148 U scale,
149 U bias,
150 U sum) {
151 static_assert(
152 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
153 "Template undefined for bits not in {2, 3, 4, 6, 8}");
154
155 U accum = 0;
156
157 if (bits == 2) {
158 for (int i = 0; i < (values_per_thread / 4); i++) {
159 accum +=
160 (x_thread[4 * i] * (w[i] & 0x03) +
161 x_thread[4 * i + 1] * (w[i] & 0x0c) +
162 x_thread[4 * i + 2] * (w[i] & 0x30) +
163 x_thread[4 * i + 3] * (w[i] & 0xc0));
164 }
165 }
166
167 else if (bits == 3) {
168 for (int i = 0; i < (values_per_thread / 8); i++) {
169 x_thread += 8 * i;
170 w += 3 * i;
171
172 accum += (w[0] & 0x07) * x_thread[0];
173 accum += (w[0] & 0x38) * x_thread[1];
174 accum += (w[0] & 0xc0) * x_thread[2];
175 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
176
177 accum += (w[1] & 0x0e) * x_thread[3];
178 accum += (w[1] & 0x70) * x_thread[4];
179 accum += (w[1] & 0x80) * x_thread[5];
180 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
181
182 accum += (w[2] & 0x1c) * x_thread[6];
183 accum += (w[2] & 0xe0) * x_thread[7];
184 }
185 }
186
187 else if (bits == 4) {
188 const device uint16_t* ws = (const device uint16_t*)w;
189 for (int i = 0; i < (values_per_thread / 4); i++) {
190 accum +=
191 (x_thread[4 * i] * (ws[i] & 0x000f) +
192 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
193 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
194 x_thread[4 * i + 3] * (ws[i] & 0xf000));
195 }
196 }
197
198 else if (bits == 6) {
199 for (int i = 0; i < (values_per_thread / 4); i++) {
200 x_thread += 4 * i;
201 w += 3 * i;
202
203 accum += (w[0] & 0x3f) * x_thread[0];
204
205 accum += (w[0] & 0xc0) * x_thread[1];
206 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
207
208 accum += (w[1] & 0xf0) * x_thread[2];
209 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
210
211 accum += (w[2] & 0xfc) * x_thread[3];
212 }
213 }
214
215 else if (bits == 8) {
216 for (int i = 0; i < values_per_thread; i++) {
217 accum += x_thread[i] * w[i];
218 }
219 }
220
221 return scale * accum + sum * bias;
222}
223
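// Same as qdot, but only the first N values are accumulated (the ragged tail
// of a vector whose length is not a multiple of the block size).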
224template <typename U, int values_per_thread, int bits>
225inline U qdot_safe(
226 const device uint8_t* w,
227 const thread U* x_thread,
228 U scale,
229 U bias,
230 U sum,
231 int N) {
232 static_assert(
233 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
234 "Template undefined for bits not in {2, 3, 4, 6, 8}");
235
236 U accum = 0;
237
238 if (bits == 2) {
239 for (int i = 0; i < (N / 4); i++) {
240 accum +=
241 (x_thread[4 * i] * (w[i] & 0x03) +
242 x_thread[4 * i + 1] * (w[i] & 0x0c) +
243 x_thread[4 * i + 2] * (w[i] & 0x30) +
244 x_thread[4 * i + 3] * (w[i] & 0xc0));
245 }
246 }
247
248 else if (bits == 3) {
249 for (int i = 0; i < (N / 8); i++) {
250 x_thread += 8 * i;
251 w += 3 * i;
252
253 accum += (w[0] & 0x07) * x_thread[0];
254 accum += (w[0] & 0x38) * x_thread[1];
255 accum += (w[0] & 0xc0) * x_thread[2];
256 accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
257
258 accum += (w[1] & 0x0e) * x_thread[3];
259 accum += (w[1] & 0x70) * x_thread[4];
260 accum += (w[1] & 0x80) * x_thread[5];
261 accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
262
263 accum += (w[2] & 0x1c) * x_thread[6];
264 accum += (w[2] & 0xe0) * x_thread[7];
265 }
266 }
267
268 else if (bits == 4) {
269 const device uint16_t* ws = (const device uint16_t*)w;
270 for (int i = 0; i < (N / 4); i++) {
271 accum +=
272 (x_thread[4 * i] * (ws[i] & 0x000f) +
273 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
274 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
275 x_thread[4 * i + 3] * (ws[i] & 0xf000));
276 }
277 }
278
279 else if (bits == 6) {
280 for (int i = 0; i < (N / 4); i++) {
281 x_thread += 4 * i;
282 w += 3 * i;
283
284 accum += (w[0] & 0x3f) * x_thread[0];
285
286 accum += (w[0] & 0xc0) * x_thread[1];
287 accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
288
289 accum += (w[1] & 0xf0) * x_thread[2];
290 accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
291
292 accum += (w[2] & 0xfc) * x_thread[3];
293 }
294 }
295
296 else if (bits == 8) {
297 for (int i = 0; i < N; i++) {
298 accum += x_thread[i] * w[i];
299 }
300 }
301
302 return scale * accum + sum * bias;
303}
304
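// qouter accumulates x times each dequantized weight into result. It is the
// building block of qvm_impl, where a single input element is multiplied
// against a strip of packed weights (an outer-product style update).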
305template <typename U, int values_per_thread, int bits>
306inline void
307qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
308 static_assert(
309 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
310 "Template undefined for bits not in {2, 3, 4, 6, 8}");
311
312 if (bits == 2) {
313 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
314 for (int i = 0; i < (values_per_thread / 4); i++) {
315 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
316 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
317 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
318 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
319 }
320 }
321
322 else if (bits == 3) {
323 for (int i = 0; i < (values_per_thread / 8); i++) {
324 uint8_t w0 = w[3 * i];
325 uint8_t w1 = w[3 * i + 1];
326 uint8_t w2 = w[3 * i + 2];
327
328 result[8 * i] += x * ((w0 & 0x7) * scale + bias);
329 result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
330 result[8 * i + 2] +=
331 x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
332 result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
333 result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
334 result[8 * i + 5] +=
335 x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
336 result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
337 result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
338 }
339 }
340
341 else if (bits == 4) {
342 U s[2] = {scale, scale / 16.0f};
343 for (int i = 0; i < (values_per_thread / 2); i++) {
344 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
345 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
346 }
347
348 } else if (bits == 6) {
349 for (int i = 0; i < (values_per_thread / 4); i++) {
350 uint8_t w0 = w[3 * i];
351 uint8_t w1 = w[3 * i + 1];
352 uint8_t w2 = w[3 * i + 2];
353
354 result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
355 result[4 * i + 1] +=
356 x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
357 result[4 * i + 2] +=
358 x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
359 result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
360 }
361 }
362
363 else if (bits == 8) {
364 for (int i = 0; i < values_per_thread; i++) {
365 result[i] += x * (scale * w[i] + bias);
366 }
367 }
368}
369
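// dequantize unpacks N quantized values from w and writes scale * q + bias to
// threadgroup memory. For 2- and 4-bit packing the per-position shift is
// folded into the precomputed scale array s[].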
370template <typename U, int N, int bits>
371inline void
372dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
373 static_assert(
374 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
375 "Template undefined for bits not in {2, 3, 4, 6, 8}");
376
377 if (bits == 2) {
378 U s[4] = {
379 scale,
380 scale / static_cast<U>(4.0f),
381 scale / static_cast<U>(16.0f),
382 scale / static_cast<U>(64.0f)};
383 for (int i = 0; i < (N / 4); i++) {
384 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
385 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
386 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
387 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
388 }
389 }
390
391 else if (bits == 3) {
392 for (int i = 0; i < (N / 8); i++) {
393 w_local += 8 * i;
394 w += 3 * i;
395
396 w_local[0] = (w[0] & 0x7) * scale + bias;
397 w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
398 w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
399 w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
400 w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
401 w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
402 w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
403 w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
404 }
405 }
406
407 else if (bits == 4) {
408 U s[2] = {scale, scale / static_cast<U>(16.0f)};
409 for (int i = 0; i < (N / 2); i++) {
410 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
411 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
412 }
413 }
414
415 else if (bits == 6) {
416 for (int i = 0; i < (N / 4); i++) {
417 w_local += 4 * i;
418 w += 3 * i;
419
420 w_local[0] = (w[0] & 0x3f) * scale + bias;
421 w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
422 w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
423 w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
424 }
425 }
426
427 else if (bits == 8) {
428 for (int i = 0; i < N; i++) {
429 w_local[i] = scale * w[i] + bias;
430 }
431 }
432}
433
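// QuantizedBlockLoader stages a BROWS x BCOLS tile of quantized weights in
// threadgroup memory, dequantizing on the fly with one scale/bias per
// quantization group. next() advances the source pointer along the reduction
// dimension and steps the scale/bias pointers whenever a full group of
// group_size columns has been consumed.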
434template <
435 typename T,
436 short BROWS,
437 short BCOLS,
438 short dst_ld,
439 short reduction_dim,
440 short tgp_size,
441 short group_size,
442 short bits>
443struct QuantizedBlockLoader {
444 static_assert(
445 BCOLS <= group_size,
446 "The group size should be larger than the columns");
447 static_assert(
448 group_size % BCOLS == 0,
449 "The group size should be divisible by the columns");
450 static_assert(
451 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
452 "Template undefined for bits not in {2, 3, 4, 6, 8}");
453
454 MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
455 MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
456 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
457 MLX_MTL_CONST short n_reads =
458 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
459 MLX_MTL_CONST short group_steps = group_size / BCOLS;
460
461 const int src_ld;
462 const int tile_stride;
463 short group_step_cnt;
464 const int group_stride;
465
466 const short thread_idx;
467 const short bi;
468 const short bj;
469
470 threadgroup T* dst;
471 const device uint8_t* src;
472 const device T* scales;
473 const device T* biases;
474
475 QuantizedBlockLoader(
476 const device uint8_t* src_,
477 const device T* scales_,
478 const device T* biases_,
479 const int src_ld_,
480 threadgroup T* dst_,
481 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
482 ushort simd_lane_id [[thread_index_in_simdgroup]])
483 : src_ld(src_ld_),
484 tile_stride(
485 reduction_dim ? BCOLS_PACKED * bytes_per_pack
486 : BROWS * src_ld * bytes_per_pack / pack_factor),
487 group_step_cnt(0),
488 group_stride(BROWS * src_ld / group_size),
489 thread_idx(simd_group_id * 32 + simd_lane_id),
490 bi(n_reads * thread_idx / BCOLS_PACKED),
491 bj((n_reads * thread_idx) % BCOLS_PACKED),
492 dst(dst_ + bi * dst_ld + bj * pack_factor),
493 src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
494 bj * bytes_per_pack),
495 scales(scales_ + bi * src_ld / group_size),
496 biases(biases_ + bi * src_ld / group_size) {}
497
498 void load_unsafe() const {
499 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
500 return;
501 }
502
503 T scale = *scales;
504 T bias = *biases;
505 for (int i = 0; i < n_reads; i++) {
506 dequantize<T, pack_factor, bits>(
507 src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
508 }
509 }
510
511 void load_safe(short2 src_tile_dim) const {
512 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
513 return;
514 }
515
516 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
517 for (int i = 0; i < n_reads * pack_factor; i++) {
518 dst[i] = T(0);
519 }
520 return;
521 }
522
523 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
524 for (int i = 0; i < n_reads * pack_factor; i++) {
525 dst[i] = T(0);
526 }
527 return;
528 }
529
530 T scale = *scales;
531 T bias = *biases;
532 for (int i = 0; i < n_reads; i++) {
533 dequantize<T, pack_factor, bits>(
534 (device uint8_t*)(src + i * bytes_per_pack),
535 scale,
536 bias,
537 dst + i * pack_factor);
538 }
539 }
540
541 void next() {
542 src += tile_stride;
543 if (reduction_dim == 1) {
544 if (group_steps > 1) {
545 group_step_cnt++;
546 if (group_step_cnt == group_steps) {
547 group_step_cnt = 0;
548 scales++;
549 biases++;
550 }
551 } else {
552 scales++;
553 biases++;
554 }
555 } else {
556 scales += group_stride;
557 biases += group_stride;
558 }
559 }
560};
561
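// Quantized matrix-vector product for small input sizes: each quadgroup of
// QUAD_SIZE threads covers the whole input vector (D / QUAD_SIZE values per
// thread) and produces up to results_per_quadgroup output rows in a single
// pass, reduced with quad_sum.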
562template <typename T, int group_size, int bits, int D>
563METAL_FUNC void qmv_quad_impl(
564 const device uint32_t* w,
565 const device T* scales,
566 const device T* biases,
567 const device T* x,
568 device T* y,
569 constant int& in_vec_size,
570 const constant int& out_vec_size,
571 uint3 tid [[threadgroup_position_in_grid]],
572 uint quad_gid [[quadgroup_index_in_threadgroup]],
573 uint quad_lid [[thread_index_in_quadgroup]]) {
574 constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
575 constexpr int pack_factor = 32 / bits;
576 constexpr int values_per_thread = D / QUAD_SIZE;
577 constexpr int packs_per_thread = values_per_thread / pack_factor;
578 constexpr int scale_step_per_thread = group_size / values_per_thread;
579 constexpr int results_per_quadgroup = 8;
580
581 typedef float U;
582
583 thread U x_thread[values_per_thread];
584 thread U result[results_per_quadgroup] = {0};
585
586 // Adjust positions
587 const int in_vec_size_w = in_vec_size / pack_factor;
588 const int in_vec_size_g = in_vec_size / group_size;
589 const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
590
591 w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
592 scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
593 biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
594 x += tid.y * in_vec_size + quad_lid * values_per_thread;
595 y += tid.y * out_vec_size + out_row;
596
597 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
598
599 for (int row = 0; row < results_per_quadgroup; row++) {
600 auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
601 const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
602 const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
603
604 U s = sl[0];
605 U b = bl[0];
606 if (row * quads_per_simd + out_row < out_vec_size) {
607 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
608 }
609 }
610
611 for (int row = 0; row < results_per_quadgroup; row++) {
612 result[row] = quad_sum(result[row]);
613 if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
614 y[row * quads_per_simd] = static_cast<T>(result[row]);
615 }
616 }
617}
618
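// Fast quantized matrix-vector path: every simdgroup owns
// results_per_simdgroup output rows and sweeps the input vector in block_size
// chunks. This path performs no bounds checks, so it assumes the sizes divide
// evenly.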
619template <typename T, int group_size, int bits>
620METAL_FUNC void qmv_fast_impl(
621 const device uint32_t* w,
622 const device T* scales,
623 const device T* biases,
624 const device T* x,
625 device T* y,
626 const constant int& in_vec_size,
627 const constant int& out_vec_size,
628 uint3 tid [[threadgroup_position_in_grid]],
629 uint simd_gid [[simdgroup_index_in_threadgroup]],
630 uint simd_lid [[thread_index_in_simdgroup]]) {
631 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
632 constexpr int packs_per_thread = bits == 2 ? 1 : 2;
633 constexpr int num_simdgroups = 2;
634 constexpr int results_per_simdgroup = 4;
635 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
636 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
637 constexpr int values_per_thread = pack_factor * packs_per_thread;
638 constexpr int block_size = values_per_thread * SIMD_SIZE;
639 constexpr int scale_step_per_thread = group_size / values_per_thread;
640
641 const device uint8_t* ws = (const device uint8_t*)w;
642
643 typedef float U;
644
645 thread U x_thread[values_per_thread];
646 thread U result[results_per_simdgroup] = {0};
647
648 // Adjust positions
649 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
650 const int in_vec_size_g = in_vec_size / group_size;
651 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
652 simd_gid * results_per_simdgroup;
653
654 ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
655 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
656 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
657 x += tid.y * in_vec_size + simd_lid * values_per_thread;
658 y += tid.y * out_vec_size + out_row;
659
660 for (int k = 0; k < in_vec_size; k += block_size) {
661 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
662
663 for (int row = 0; row < results_per_simdgroup; row++) {
664 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
665 const device T* sl = scales + row * in_vec_size_g;
666 const device T* bl = biases + row * in_vec_size_g;
667
668 U s = sl[0];
669 U b = bl[0];
670 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
671 }
672
673 ws += block_size * bytes_per_pack / pack_factor;
674 scales += block_size / group_size;
675 biases += block_size / group_size;
676 x += block_size;
677 }
678
679 for (int row = 0; row < results_per_simdgroup; row++) {
680 result[row] = simd_sum(result[row]);
681 if (simd_lid == 0) {
682 y[row] = static_cast<T>(result[row]);
683 }
684 }
685}
686
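// General quantized matrix-vector product with edge handling. When fewer
// output rows than a full tile remain, every read is guarded; otherwise the
// last tile is shifted back (used_out_row) so some rows are recomputed. The
// ragged tail of the input vector goes through load_vector_safe / qdot_safe.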
687template <typename T, int group_size, int bits>
688METAL_FUNC void qmv_impl(
689 const device uint32_t* w,
690 const device T* scales,
691 const device T* biases,
692 const device T* x,
693 device T* y,
694 const constant int& in_vec_size,
695 const constant int& out_vec_size,
696 uint3 tid [[threadgroup_position_in_grid]],
697 uint simd_gid [[simdgroup_index_in_threadgroup]],
698 uint simd_lid [[thread_index_in_simdgroup]]) {
699 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
700 constexpr int num_simdgroups = 2;
701 constexpr int results_per_simdgroup = 4;
702 constexpr int packs_per_thread = 1;
703 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
704 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
705 constexpr int values_per_thread = pack_factor * packs_per_thread;
706 constexpr int block_size = values_per_thread * SIMD_SIZE;
707 constexpr int scale_step_per_thread = group_size / values_per_thread;
708
709 const device uint8_t* ws = (const device uint8_t*)w;
710
711 typedef float U;
712
713 thread U x_thread[values_per_thread];
714 thread U result[results_per_simdgroup] = {0};
715
716 // Adjust positions
717 const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
718 const int in_vec_size_g = in_vec_size / group_size;
719 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
720 simd_gid * results_per_simdgroup;
721 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
722
723 if (out_row >= out_vec_size) {
724 return;
725 }
726
727 // In this case we need to properly guard all our reads because there isn't
728 // even 1 tile in the matrix
729 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
730 ws +=
731 out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
732 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
733 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
734 x += tid.y * in_vec_size + simd_lid * values_per_thread;
735 y += tid.y * out_vec_size + out_row;
736
737 int k = 0;
738 for (; k < in_vec_size - block_size; k += block_size) {
739 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
740
741 for (int row = 0; out_row + row < out_vec_size; row++) {
742 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
743 const device T* sl = scales + row * in_vec_size_g;
744 const device T* bl = biases + row * in_vec_size_g;
745
746 U s = sl[0];
747 U b = bl[0];
748 result[row] +=
749 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
750 }
751
752 ws += block_size * bytes_per_pack / pack_factor;
753 scales += block_size / group_size;
754 biases += block_size / group_size;
755 x += block_size;
756 }
757 const int remaining = clamp(
758 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
759 0,
760 values_per_thread);
761 if (remaining > 0) {
762 U sum = load_vector_safe<T, U, values_per_thread, bits>(
763 x, x_thread, remaining);
764
765 for (int row = 0; out_row + row < out_vec_size; row++) {
766 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
767 const device T* sl = scales + row * in_vec_size_g;
768 const device T* bl = biases + row * in_vec_size_g;
769
770 U s = sl[0];
771 U b = bl[0];
772 result[row] +=
773 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
774 }
775 }
776
777 for (int row = 0; out_row + row < out_vec_size; row++) {
778 result[row] = simd_sum(result[row]);
779 if (simd_lid == 0) {
780 y[row] = static_cast<T>(result[row]);
781 }
782 }
783 }
784
785 // In this case the last tile is moved back to redo some output values
786 else {
787 ws += used_out_row * in_vec_size_w +
788 simd_lid * packs_per_thread * bytes_per_pack;
789 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
790 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
791 x += tid.y * in_vec_size + simd_lid * values_per_thread;
792 y += tid.y * out_vec_size + used_out_row;
793
794 int k = 0;
795 for (; k < in_vec_size - block_size; k += block_size) {
796 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
797
798 for (int row = 0; row < results_per_simdgroup; row++) {
799 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
800 const device T* sl = scales + row * in_vec_size_g;
801 const device T* bl = biases + row * in_vec_size_g;
802
803 U s = sl[0];
804 U b = bl[0];
805 result[row] +=
806 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
807 }
808
809 ws += block_size * bytes_per_pack / pack_factor;
810 scales += block_size / group_size;
811 biases += block_size / group_size;
812 x += block_size;
813 }
814 const int remaining = clamp(
815 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
816 0,
817 values_per_thread);
818 if (remaining > 0) {
819 U sum = load_vector_safe<T, U, values_per_thread, bits>(
820 x, x_thread, remaining);
821
822 for (int row = 0; row < results_per_simdgroup; row++) {
823 auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
824 const device T* sl = scales + row * in_vec_size_g;
825 const device T* bl = biases + row * in_vec_size_g;
826
827 U s = sl[0];
828 U b = bl[0];
829 result[row] += qdot_safe<U, values_per_thread, bits>(
830 wl, x_thread, s, b, sum, remaining);
831 }
832 }
833 for (int row = 0; row < results_per_simdgroup; row++) {
834 result[row] = simd_sum(result[row]);
835 if (simd_lid == 0) {
836 y[row] = static_cast<T>(result[row]);
837 }
838 }
839 }
840}
841
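// Quantized vector-matrix product (x times W). Each thread loads one input
// element and one strip of packed weights per iteration, accumulates with
// qouter, and the per-thread partial results are reduced across the simdgroup
// with simd_sum before a single lane writes them out.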
842template <typename T, const int group_size, const int bits>
843METAL_FUNC void qvm_impl(
844 const device uint32_t* w,
845 const device T* scales,
846 const device T* biases,
847 const device T* x,
848 device T* y,
849 const int in_vec_size,
850 const int out_vec_size,
851 uint3 tid [[threadgroup_position_in_grid]],
852 uint simd_gid [[simdgroup_index_in_threadgroup]],
853 uint simd_lid [[thread_index_in_simdgroup]]) {
854 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
855 constexpr int num_simdgroups = 2;
856 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
857 constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
858 constexpr int tn = 32 / pack_factor;
859 constexpr int block_size = SIMD_SIZE;
860
861 const device uint8_t* ws = (const device uint8_t*)w;
862
863 typedef float U;
864 typedef struct {
865 uint8_t wi[tn * bytes_per_pack];
866 } vec_w;
867
868 thread vec_w w_local;
869 thread U result[tn * pack_factor] = {0};
870 thread U scale = 1;
871 thread U bias = 0;
872 thread U x_local = 0;
873
874 // Adjust positions
875 const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
876 const int out_vec_size_g = out_vec_size / group_size;
877 int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
878 ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
879 scales += out_col / group_size + simd_lid * out_vec_size_g;
880 biases += out_col / group_size + simd_lid * out_vec_size_g;
881 x += tid.y * in_vec_size + simd_lid;
882 y += tid.y * out_vec_size + out_col;
883
884 if (out_col >= out_vec_size) {
885 return;
886 }
887
888 // Loop over in_vec in blocks of block_size
889 int remaining = in_vec_size % block_size;
890 if (remaining == 0) {
891 for (int i = 0; i < in_vec_size; i += block_size) {
892 x_local = *x;
893 scale = *scales;
894 bias = *biases;
895 w_local = *((device vec_w*)ws);
896 qouter<U, tn * pack_factor, bits>(
897 (thread uint8_t*)&w_local, x_local, scale, bias, result);
898
899 x += block_size;
900 scales += block_size * out_vec_size_g;
901 biases += block_size * out_vec_size_g;
902 ws += block_size * out_vec_size_w;
903 }
904 } else {
905 for (int i = block_size; i < in_vec_size; i += block_size) {
906 x_local = *x;
907 scale = *scales;
908 bias = *biases;
909 w_local = *((device vec_w*)ws);
910
911 qouter<U, tn * pack_factor, bits>(
912 (thread uint8_t*)&w_local, x_local, scale, bias, result);
913
914 x += block_size;
915 scales += block_size * out_vec_size_g;
916 biases += block_size * out_vec_size_g;
917 ws += block_size * out_vec_size_w;
918 }
919 if (static_cast<int>(simd_lid) < remaining) {
920 x_local = *x;
921 scale = *scales;
922 bias = *biases;
923 w_local = *((device vec_w*)ws);
924 } else {
925 x_local = 0;
926 scale = 0;
927 bias = 0;
928 }
929 qouter<U, tn * pack_factor, bits>(
930 (thread uint8_t*)&w_local, x_local, scale, bias, result);
931 }
932
933// Accumulate in the simdgroup
934#pragma clang loop unroll(full)
935 for (int k = 0; k < tn * pack_factor; k++) {
936 result[k] = simd_sum(result[k]);
937 }
938
939 // Store the result
940 if (simd_lid == 0) {
941#pragma clang loop unroll(full)
942 for (int k = 0; k < tn * pack_factor; k++) {
943 y[k] = static_cast<T>(result[k]);
944 }
945 }
946}
947
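// Blocked matmul y = x * W^T with W quantized. Tiles of x and dequantized W
// are staged in threadgroup memory through BlockLoader / QuantizedBlockLoader
// and multiplied with the steel BlockMMA. aligned_N lets the column bounds
// checks be skipped; ragged M / N edges use the safe load and store paths.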
948template <
949 typename T,
950 const int group_size,
951 const int bits,
952 const bool aligned_N,
953 const int BM = 32,
954 const int BK = 32,
955 const int BN = 32>
956METAL_FUNC void qmm_t_impl(
957 const device uint32_t* w,
958 const device T* scales,
959 const device T* biases,
960 const device T* x,
961 device T* y,
962 threadgroup T* Xs,
963 threadgroup T* Ws,
964 const constant int& K,
965 const constant int& N,
966 const constant int& M,
967 uint3 tid [[threadgroup_position_in_grid]],
968 uint lid [[thread_index_in_threadgroup]],
969 uint simd_gid [[simdgroup_index_in_threadgroup]],
970 uint simd_lid [[thread_index_in_simdgroup]]) {
971 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
972 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
973
974 (void)lid;
975
976 constexpr int WM = 2;
977 constexpr int WN = 2;
978 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
979 constexpr int BK_padded = (BK + 16 / sizeof(T));
980 constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
981
982 // Instantiate the appropriate BlockMMA and Loader
983 using mma_t = mlx::steel::
984 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
985 using loader_x_t =
986 mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
987 using loader_w_t = QuantizedBlockLoader<
988 T,
989 BN,
990 BK,
991 BK_padded,
992 1,
993 WM * WN * SIMD_SIZE,
994 group_size,
995 bits>;
996
997 // Set the block
998 const int K_w = K * bytes_per_pack / pack_factor;
999 const int K_g = K / group_size;
1000 const int y_row = tid.y * BM;
1001 const int y_col = tid.x * BN;
1002
1003 auto wl = (const device uint8_t*)w;
1004
1005 x += y_row * K;
1006 wl += y_col * K_w;
1007 scales += y_col * K_g;
1008 biases += y_col * K_g;
1009 y += y_row * N + y_col;
1010
1011 // Make the x loader and mma operation
1012 const short num_els = min(BM, M - y_row);
1013 const short num_outs = min(BN, N - y_col);
1014 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1015 loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
1016 mma_t mma_op(simd_gid, simd_lid);
1017
1018 if (num_els < BM) {
1019 if (!aligned_N && num_outs < BN) {
1020 for (int k = 0; k < K; k += BK) {
1021 threadgroup_barrier(mem_flags::mem_threadgroup);
1022 loader_x.load_safe(short2(BK, num_els));
1023 loader_w.load_safe(short2(BK, num_outs));
1024 threadgroup_barrier(mem_flags::mem_threadgroup);
1025 mma_op.mma(Xs, Ws);
1026 loader_x.next();
1027 loader_w.next();
1028 }
1029 } else {
1030 for (int k = 0; k < K; k += BK) {
1031 threadgroup_barrier(mem_flags::mem_threadgroup);
1032 loader_x.load_safe(short2(BK, num_els));
1033 loader_w.load_unsafe();
1034 threadgroup_barrier(mem_flags::mem_threadgroup);
1035 mma_op.mma(Xs, Ws);
1036 loader_x.next();
1037 loader_w.next();
1038 }
1039 }
1040 } else {
1041 if (!aligned_N && num_outs < BN) {
1042 for (int k = 0; k < K; k += BK) {
1043 threadgroup_barrier(mem_flags::mem_threadgroup);
1044 loader_x.load_unsafe();
1045 loader_w.load_safe(short2(BK, num_outs));
1046 threadgroup_barrier(mem_flags::mem_threadgroup);
1047 mma_op.mma(Xs, Ws);
1048 loader_x.next();
1049 loader_w.next();
1050 }
1051 } else {
1052 for (int k = 0; k < K; k += BK) {
1053 threadgroup_barrier(mem_flags::mem_threadgroup);
1054 loader_x.load_unsafe();
1055 loader_w.load_unsafe();
1056 threadgroup_barrier(mem_flags::mem_threadgroup);
1057
1058 mma_op.mma(Xs, Ws);
1059 loader_x.next();
1060 loader_w.next();
1061 }
1062 }
1063 }
1064
1065 // Store results to device memory
1066 threadgroup_barrier(mem_flags::mem_threadgroup);
1067 if (num_els < BM || num_outs < BN) {
1068 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
1069 } else {
1070 mma_op.store_result(y, N);
1071 }
1072}
1073
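// Blocked matmul y = x * W with W quantized in the non-transposed layout. The
// structure mirrors qmm_t_impl, except the K tail (K % BK != 0) is handled
// explicitly with safe loads for the final partial block.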
1074template <
1075 typename T,
1076 const int group_size,
1077 const int bits,
1078 const int BM = 32,
1079 const int BK = 32,
1080 const int BN = 32>
1081METAL_FUNC void qmm_n_impl(
1082 const device uint32_t* w,
1083 const device T* scales,
1084 const device T* biases,
1085 const device T* x,
1086 device T* y,
1087 threadgroup T* Xs,
1088 threadgroup T* Ws,
1089 const constant int& K,
1090 const constant int& N,
1091 const constant int& M,
1092 uint3 tid [[threadgroup_position_in_grid]],
1093 uint lid [[thread_index_in_threadgroup]],
1094 uint simd_gid [[simdgroup_index_in_threadgroup]],
1095 uint simd_lid [[thread_index_in_simdgroup]]) {
1096 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
1097 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
1098
1099 (void)lid;
1100
1101 constexpr int WM = 2;
1102 constexpr int WN = 2;
1103 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
1104 constexpr int BK_padded = (BK + 16 / sizeof(T));
1105 constexpr int BN_padded = (BN + 16 / sizeof(T));
1106 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
1107 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
1108
1109 // Instantiate the appropriate BlockMMA and Loader
1110 using mma_t = mlx::steel::
1111 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
1112 using loader_x_t = mlx::steel::
1113 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
1114 using loader_w_t = QuantizedBlockLoader<
1115 T,
1116 BK,
1117 BN,
1118 BN_padded,
1119 0,
1120 WM * WN * SIMD_SIZE,
1121 group_size,
1122 bits>;
1123
1124 auto wl = (const device uint8_t*)w;
1125
1126 // Set the block
1127 const int y_row = tid.y * BM;
1128 const int y_col = tid.x * BN;
1129 x += y_row * K;
1130 wl += y_col * bytes_per_pack / pack_factor;
1131 scales += y_col / group_size;
1132 biases += y_col / group_size;
1133 y += y_row * N + y_col;
1134
1135 // Make the x loader and mma operation
1136 const short num_els = min(BM, M - y_row);
1137 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
1138 loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
1139 mma_t mma_op(simd_gid, simd_lid);
1140
1141 if (num_els < BM) {
1142 if ((K % BK) != 0) {
1143 const int k_blocks = K / BK;
1144 for (int k = 0; k < k_blocks; k++) {
1145 threadgroup_barrier(mem_flags::mem_threadgroup);
1146 loader_x.load_safe(short2(BK, num_els));
1147 loader_w.load_unsafe();
1148 threadgroup_barrier(mem_flags::mem_threadgroup);
1149 mma_op.mma(Xs, Ws);
1150 loader_x.next();
1151 loader_w.next();
1152 }
1153 const short num_k = K - k_blocks * BK;
1154 threadgroup_barrier(mem_flags::mem_threadgroup);
1155 loader_x.load_safe(short2(num_k, num_els));
1156 loader_w.load_safe(short2(BN, num_k));
1157 threadgroup_barrier(mem_flags::mem_threadgroup);
1158 mma_op.mma(Xs, Ws);
1159 } else {
1160 for (int k = 0; k < K; k += BK) {
1161 threadgroup_barrier(mem_flags::mem_threadgroup);
1162 loader_x.load_safe(short2(BK, num_els));
1163 loader_w.load_unsafe();
1164 threadgroup_barrier(mem_flags::mem_threadgroup);
1165 mma_op.mma(Xs, Ws);
1166 loader_x.next();
1167 loader_w.next();
1168 }
1169 }
1170 } else {
1171 if ((K % BK) != 0) {
1172 const int k_blocks = K / BK;
1173 for (int k = 0; k < k_blocks; k++) {
1174 threadgroup_barrier(mem_flags::mem_threadgroup);
1175 loader_x.load_unsafe();
1176 loader_w.load_unsafe();
1177 threadgroup_barrier(mem_flags::mem_threadgroup);
1178 mma_op.mma(Xs, Ws);
1179 loader_x.next();
1180 loader_w.next();
1181 }
1182 const short num_k = K - k_blocks * BK;
1183 threadgroup_barrier(mem_flags::mem_threadgroup);
1184 loader_x.load_safe(short2(num_k, BM));
1185 loader_w.load_safe(short2(BN, num_k));
1186 threadgroup_barrier(mem_flags::mem_threadgroup);
1187 mma_op.mma(Xs, Ws);
1188 } else {
1189 for (int k = 0; k < K; k += BK) {
1190 threadgroup_barrier(mem_flags::mem_threadgroup);
1191 loader_x.load_unsafe();
1192 loader_w.load_unsafe();
1193 threadgroup_barrier(mem_flags::mem_threadgroup);
1194 mma_op.mma(Xs, Ws);
1195 loader_x.next();
1196 loader_w.next();
1197 }
1198 }
1199 }
1200
1201 // Store results to device memory
1202 threadgroup_barrier(mem_flags::mem_threadgroup);
1203 if (num_els < BM) {
1204 mma_op.store_result_safe(y, N, short2(BN, num_els));
1205 } else {
1206 mma_op.store_result(y, N);
1207 }
1208}
1209
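// The two adjust_matrix_offsets overloads below advance the x, w, scales,
// biases, and y pointers to the batch element selected by tid.z. The second
// overload additionally gathers the batch element through lhs_indices /
// rhs_indices.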
1210template <typename T>
1211METAL_FUNC void adjust_matrix_offsets(
1212 const device T*& x,
1213 const device uint32_t*& w,
1214 const device T*& scales,
1215 const device T*& biases,
1216 device T*& y,
1217 int output_stride,
1218 const constant int& x_batch_ndims,
1219 const constant int* x_shape,
1220 const constant size_t* x_strides,
1221 const constant int& w_batch_ndims,
1222 const constant int* w_shape,
1223 const constant size_t* w_strides,
1224 const constant size_t* s_strides,
1225 const constant size_t* b_strides,
1226 uint3 tid [[threadgroup_position_in_grid]]) {
1227 // Set the input/output matrices
1228 uint32_t x_idx = tid.z;
1229 uint32_t w_idx = tid.z;
1230 if (x_batch_ndims == 1) {
1231 x += x_idx * x_strides[0];
1232 } else {
1233 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1234 }
1235 if (w_batch_ndims == 1) {
1236 w += w_idx * w_strides[0];
1237 scales += w_idx * s_strides[0];
1238 biases += w_idx * b_strides[0];
1239 } else {
1240 ulong3 idx = elem_to_loc_broadcast(
1241 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1242 w += idx.x;
1243 scales += idx.y;
1244 biases += idx.z;
1245 }
1246 y += tid.z * output_stride;
1247}
1248
1249template <typename T>
1250METAL_FUNC void adjust_matrix_offsets(
1251 const device T*& x,
1252 const device uint32_t*& w,
1253 const device T*& scales,
1254 const device T*& biases,
1255 const device uint32_t* lhs_indices,
1256 const device uint32_t* rhs_indices,
1257 device T*& y,
1258 int output_stride,
1259 const constant int& batch_ndims,
1260 const constant int* batch_shape,
1261 const constant size_t* lhs_strides,
1262 const constant size_t* rhs_strides,
1263 const constant int& x_batch_ndims,
1264 const constant int* x_shape,
1265 const constant size_t* x_strides,
1266 const constant int& w_batch_ndims,
1267 const constant int* w_shape,
1268 const constant size_t* w_strides,
1269 const constant size_t* s_strides,
1270 const constant size_t* b_strides,
1271 uint3 tid [[threadgroup_position_in_grid]]) {
1272 // Set the input/output matrices
1273 uint32_t x_idx;
1274 uint32_t w_idx;
1275 if (batch_ndims == 1) {
1276 x_idx = lhs_indices[tid.z * lhs_strides[0]];
1277 w_idx = rhs_indices[tid.z * rhs_strides[0]];
1278 } else {
1279 ulong2 idx = elem_to_loc_broadcast(
1280 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
1281 x_idx = lhs_indices[idx.x];
1282 w_idx = rhs_indices[idx.y];
1283 }
1284 if (x_batch_ndims == 1) {
1285 x += x_idx * x_strides[0];
1286 } else {
1287 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1288 }
1289 if (w_batch_ndims == 1) {
1290 w += w_idx * w_strides[0];
1291 scales += w_idx * s_strides[0];
1292 biases += w_idx * b_strides[0];
1293 } else {
1294 ulong3 idx = elem_to_loc_broadcast(
1295 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1296 w += idx.x;
1297 scales += idx.y;
1298 biases += idx.z;
1299 }
1300 y += tid.z * output_stride;
1301}
1302
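// The [[kernel]] entry points below bind the device buffers, optionally apply
// the batch offsets, and dispatch to the *_impl helpers above. The bs_*
// variants are the gather (batched-indices) versions.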
1303template <typename T, int group_size, int bits, int D, bool batched>
1304[[kernel]] void qmv_quad(
1305 const device uint32_t* w [[buffer(0)]],
1306 const device T* scales [[buffer(1)]],
1307 const device T* biases [[buffer(2)]],
1308 const device T* x [[buffer(3)]],
1309 device T* y [[buffer(4)]],
1310 const constant int& in_vec_size [[buffer(5)]],
1311 const constant int& out_vec_size [[buffer(6)]],
1312 const constant int& x_batch_ndims [[buffer(7)]],
1313 const constant int* x_shape [[buffer(8)]],
1314 const constant size_t* x_strides [[buffer(9)]],
1315 const constant int& w_batch_ndims [[buffer(10)]],
1316 const constant int* w_shape [[buffer(11)]],
1317 const constant size_t* w_strides [[buffer(12)]],
1318 const constant size_t* s_strides [[buffer(13)]],
1319 const constant size_t* b_strides [[buffer(14)]],
1320 uint3 tid [[threadgroup_position_in_grid]],
1321 uint quad_gid [[quadgroup_index_in_threadgroup]],
1322 uint quad_lid [[thread_index_in_quadgroup]]) {
1323 if (batched) {
1324 adjust_matrix_offsets<T>(
1325 x,
1326 w,
1327 scales,
1328 biases,
1329 y,
1330 out_vec_size,
1331 x_batch_ndims,
1332 x_shape,
1333 x_strides,
1334 w_batch_ndims,
1335 w_shape,
1336 w_strides,
1337 s_strides,
1338 b_strides,
1339 tid);
1340 }
1341 qmv_quad_impl<T, group_size, bits, D>(
1342 w,
1343 scales,
1344 biases,
1345 x,
1346 y,
1347 in_vec_size,
1348 out_vec_size,
1349 tid,
1350 quad_gid,
1351 quad_lid);
1352}
1353
1354template <typename T, int group_size, int bits, bool batched>
1355[[kernel]] void qmv_fast(
1356 const device uint32_t* w [[buffer(0)]],
1357 const device T* scales [[buffer(1)]],
1358 const device T* biases [[buffer(2)]],
1359 const device T* x [[buffer(3)]],
1360 device T* y [[buffer(4)]],
1361 const constant int& in_vec_size [[buffer(5)]],
1362 const constant int& out_vec_size [[buffer(6)]],
1363 const constant int& x_batch_ndims [[buffer(7)]],
1364 const constant int* x_shape [[buffer(8)]],
1365 const constant size_t* x_strides [[buffer(9)]],
1366 const constant int& w_batch_ndims [[buffer(10)]],
1367 const constant int* w_shape [[buffer(11)]],
1368 const constant size_t* w_strides [[buffer(12)]],
1369 const constant size_t* s_strides [[buffer(13)]],
1370 const constant size_t* b_strides [[buffer(14)]],
1371 uint3 tid [[threadgroup_position_in_grid]],
1372 uint simd_gid [[simdgroup_index_in_threadgroup]],
1373 uint simd_lid [[thread_index_in_simdgroup]]) {
1374 if (batched) {
1375 adjust_matrix_offsets<T>(
1376 x,
1377 w,
1378 scales,
1379 biases,
1380 y,
1381 out_vec_size,
1382 x_batch_ndims,
1383 x_shape,
1384 x_strides,
1385 w_batch_ndims,
1386 w_shape,
1387 w_strides,
1388 s_strides,
1389 b_strides,
1390 tid);
1391 }
1392 qmv_fast_impl<T, group_size, bits>(
1393 w,
1394 scales,
1395 biases,
1396 x,
1397 y,
1398 in_vec_size,
1399 out_vec_size,
1400 tid,
1401 simd_gid,
1402 simd_lid);
1403}
1404
1405template <typename T, const int group_size, const int bits, bool batched>
1406[[kernel]] void qmv(
1407 const device uint32_t* w [[buffer(0)]],
1408 const device T* scales [[buffer(1)]],
1409 const device T* biases [[buffer(2)]],
1410 const device T* x [[buffer(3)]],
1411 device T* y [[buffer(4)]],
1412 const constant int& in_vec_size [[buffer(5)]],
1413 const constant int& out_vec_size [[buffer(6)]],
1414 const constant int& x_batch_ndims [[buffer(7)]],
1415 const constant int* x_shape [[buffer(8)]],
1416 const constant size_t* x_strides [[buffer(9)]],
1417 const constant int& w_batch_ndims [[buffer(10)]],
1418 const constant int* w_shape [[buffer(11)]],
1419 const constant size_t* w_strides [[buffer(12)]],
1420 const constant size_t* s_strides [[buffer(13)]],
1421 const constant size_t* b_strides [[buffer(14)]],
1422 uint3 tid [[threadgroup_position_in_grid]],
1423 uint simd_gid [[simdgroup_index_in_threadgroup]],
1424 uint simd_lid [[thread_index_in_simdgroup]]) {
1425 if (batched) {
1426 adjust_matrix_offsets<T>(
1427 x,
1428 w,
1429 scales,
1430 biases,
1431 y,
1432 out_vec_size,
1433 x_batch_ndims,
1434 x_shape,
1435 x_strides,
1436 w_batch_ndims,
1437 w_shape,
1438 w_strides,
1439 s_strides,
1440 b_strides,
1441 tid);
1442 }
1443 qmv_impl<T, group_size, bits>(
1444 w,
1445 scales,
1446 biases,
1447 x,
1448 y,
1449 in_vec_size,
1450 out_vec_size,
1451 tid,
1452 simd_gid,
1453 simd_lid);
1454}
1455
1456template <typename T, const int group_size, const int bits, bool batched>
1457[[kernel]] void qvm(
1458 const device uint32_t* w [[buffer(0)]],
1459 const device T* scales [[buffer(1)]],
1460 const device T* biases [[buffer(2)]],
1461 const device T* x [[buffer(3)]],
1462 device T* y [[buffer(4)]],
1463 const constant int& in_vec_size [[buffer(5)]],
1464 const constant int& out_vec_size [[buffer(6)]],
1465 const constant int& x_batch_ndims [[buffer(7)]],
1466 const constant int* x_shape [[buffer(8)]],
1467 const constant size_t* x_strides [[buffer(9)]],
1468 const constant int& w_batch_ndims [[buffer(10)]],
1469 const constant int* w_shape [[buffer(11)]],
1470 const constant size_t* w_strides [[buffer(12)]],
1471 const constant size_t* s_strides [[buffer(13)]],
1472 const constant size_t* b_strides [[buffer(14)]],
1473 uint3 tid [[threadgroup_position_in_grid]],
1474 uint simd_gid [[simdgroup_index_in_threadgroup]],
1475 uint simd_lid [[thread_index_in_simdgroup]]) {
1476 if (batched) {
1477 adjust_matrix_offsets<T>(
1478 x,
1479 w,
1480 scales,
1481 biases,
1482 y,
1483 out_vec_size,
1484 x_batch_ndims,
1485 x_shape,
1486 x_strides,
1487 w_batch_ndims,
1488 w_shape,
1489 w_strides,
1490 s_strides,
1491 b_strides,
1492 tid);
1493 }
1494 qvm_impl<T, group_size, bits>(
1495 w,
1496 scales,
1497 biases,
1498 x,
1499 y,
1500 in_vec_size,
1501 out_vec_size,
1502 tid,
1503 simd_gid,
1504 simd_lid);
1505}
1506
1507template <typename T, const int group_size, const int bits, int split_k = 32>
1508[[kernel]] void qvm_split_k(
1509 const device uint32_t* w [[buffer(0)]],
1510 const device T* scales [[buffer(1)]],
1511 const device T* biases [[buffer(2)]],
1512 const device T* x [[buffer(3)]],
1513 device T* y [[buffer(4)]],
1514 const constant int& in_vec_size [[buffer(5)]],
1515 const constant int& out_vec_size [[buffer(6)]],
1516 const constant int& x_batch_ndims [[buffer(7)]],
1517 const constant int* x_shape [[buffer(8)]],
1518 const constant size_t* x_strides [[buffer(9)]],
1519 const constant int& w_batch_ndims [[buffer(10)]],
1520 const constant int* w_shape [[buffer(11)]],
1521 const constant size_t* w_strides [[buffer(12)]],
1522 const constant size_t* s_strides [[buffer(13)]],
1523 const constant size_t* b_strides [[buffer(14)]],
1524 const constant int& final_block_size [[buffer(15)]],
1525 uint3 tid [[threadgroup_position_in_grid]],
1526 uint simd_gid [[simdgroup_index_in_threadgroup]],
1527 uint simd_lid [[thread_index_in_simdgroup]]) {
1528 adjust_matrix_offsets<T>(
1529 x,
1530 w,
1531 scales,
1532 biases,
1533 y,
1534 out_vec_size,
1535 x_batch_ndims,
1536 x_shape,
1537 x_strides,
1538 w_batch_ndims,
1539 w_shape,
1540 w_strides,
1541 s_strides,
1542 b_strides,
1543 tid);
1544
1545 // When (in_vec_size % split_k != 0) the final block needs to be smaller
1546 int in_vec_size_adj =
1547 tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
1548
1549 qvm_impl<T, group_size, bits>(
1550 w,
1551 scales,
1552 biases,
1553 x,
1554 y,
1555 in_vec_size_adj,
1556 out_vec_size,
1557 tid,
1558 simd_gid,
1559 simd_lid);
1560}
1561
1562template <
1563 typename T,
1564 const int group_size,
1565 const int bits,
1566 const bool aligned_N,
1567 const bool batched,
1568 const int BM = 32,
1569 const int BK = 32,
1570 const int BN = 32>
1571[[kernel]] void qmm_t(
1572 const device uint32_t* w [[buffer(0)]],
1573 const device T* scales [[buffer(1)]],
1574 const device T* biases [[buffer(2)]],
1575 const device T* x [[buffer(3)]],
1576 device T* y [[buffer(4)]],
1577 const constant int& K [[buffer(5)]],
1578 const constant int& N [[buffer(6)]],
1579 const constant int& M [[buffer(7)]],
1580 const constant int& x_batch_ndims [[buffer(8)]],
1581 const constant int* x_shape [[buffer(9)]],
1582 const constant size_t* x_strides [[buffer(10)]],
1583 const constant int& w_batch_ndims [[buffer(11)]],
1584 const constant int* w_shape [[buffer(12)]],
1585 const constant size_t* w_strides [[buffer(13)]],
1586 const constant size_t* s_strides [[buffer(14)]],
1587 const constant size_t* b_strides [[buffer(15)]],
1588 uint3 tid [[threadgroup_position_in_grid]],
1589 uint lid [[thread_index_in_threadgroup]],
1590 uint simd_gid [[simdgroup_index_in_threadgroup]],
1591 uint simd_lid [[thread_index_in_simdgroup]]) {
1592 (void)lid;
1593
1594 constexpr int BK_padded = (BK + 16 / sizeof(T));
1595
1596 threadgroup T Xs[BM * BK_padded];
1597 threadgroup T Ws[BN * BK_padded];
1598
1599 if (batched) {
1600 adjust_matrix_offsets<T>(
1601 x,
1602 w,
1603 scales,
1604 biases,
1605 y,
1606 M * N,
1607 x_batch_ndims,
1608 x_shape,
1609 x_strides,
1610 w_batch_ndims,
1611 w_shape,
1612 w_strides,
1613 s_strides,
1614 b_strides,
1615 tid);
1616 }
1617 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1618 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1619}
1620
1621template <
1622 typename T,
1623 const int group_size,
1624 const int bits,
1625 const bool batched,
1626 const int BM = 32,
1627 const int BK = 32,
1628 const int BN = 32>
1629[[kernel]] void qmm_n(
1630 const device uint32_t* w [[buffer(0)]],
1631 const device T* scales [[buffer(1)]],
1632 const device T* biases [[buffer(2)]],
1633 const device T* x [[buffer(3)]],
1634 device T* y [[buffer(4)]],
1635 const constant int& K [[buffer(5)]],
1636 const constant int& N [[buffer(6)]],
1637 const constant int& M [[buffer(7)]],
1638 const constant int& x_batch_ndims [[buffer(8)]],
1639 const constant int* x_shape [[buffer(9)]],
1640 const constant size_t* x_strides [[buffer(10)]],
1641 const constant int& w_batch_ndims [[buffer(11)]],
1642 const constant int* w_shape [[buffer(12)]],
1643 const constant size_t* w_strides [[buffer(13)]],
1644 const constant size_t* s_strides [[buffer(14)]],
1645 const constant size_t* b_strides [[buffer(15)]],
1646 uint3 tid [[threadgroup_position_in_grid]],
1647 uint lid [[thread_index_in_threadgroup]],
1648 uint simd_gid [[simdgroup_index_in_threadgroup]],
1649 uint simd_lid [[thread_index_in_simdgroup]]) {
1650 (void)lid;
1651
1652 constexpr int BK_padded = (BK + 16 / sizeof(T));
1653 constexpr int BN_padded = (BN + 16 / sizeof(T));
1654
1655 threadgroup T Xs[BM * BK_padded];
1656 threadgroup T Ws[BK * BN_padded];
1657
1658 if (batched) {
1659 adjust_matrix_offsets<T>(
1660 x,
1661 w,
1662 scales,
1663 biases,
1664 y,
1665 M * N,
1666 x_batch_ndims,
1667 x_shape,
1668 x_strides,
1669 w_batch_ndims,
1670 w_shape,
1671 w_strides,
1672 s_strides,
1673 b_strides,
1674 tid);
1675 }
1676
1677 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1678 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1679}
1680
1681template <typename T, int group_size, int bits>
1682[[kernel]] void bs_qmv_fast(
1683 const device uint32_t* w [[buffer(0)]],
1684 const device T* scales [[buffer(1)]],
1685 const device T* biases [[buffer(2)]],
1686 const device T* x [[buffer(3)]],
1687 device T* y [[buffer(4)]],
1688 const constant int& in_vec_size [[buffer(5)]],
1689 const constant int& out_vec_size [[buffer(6)]],
1690 const constant int& x_batch_ndims [[buffer(7)]],
1691 const constant int* x_shape [[buffer(8)]],
1692 const constant size_t* x_strides [[buffer(9)]],
1693 const constant int& w_batch_ndims [[buffer(10)]],
1694 const constant int* w_shape [[buffer(11)]],
1695 const constant size_t* w_strides [[buffer(12)]],
1696 const constant size_t* s_strides [[buffer(13)]],
1697 const constant size_t* b_strides [[buffer(14)]],
1698 const constant int& batch_ndims [[buffer(15)]],
1699 const constant int* batch_shape [[buffer(16)]],
1700 const device uint32_t* lhs_indices [[buffer(17)]],
1701 const device uint32_t* rhs_indices [[buffer(18)]],
1702 const constant size_t* lhs_strides [[buffer(19)]],
1703 const constant size_t* rhs_strides [[buffer(20)]],
1704 uint3 tid [[threadgroup_position_in_grid]],
1705 uint simd_gid [[simdgroup_index_in_threadgroup]],
1706 uint simd_lid [[thread_index_in_simdgroup]]) {
1707 adjust_matrix_offsets<T>(
1708 x,
1709 w,
1710 scales,
1711 biases,
1712 lhs_indices,
1713 rhs_indices,
1714 y,
1715 out_vec_size,
1716 batch_ndims,
1717 batch_shape,
1718 lhs_strides,
1719 rhs_strides,
1720 x_batch_ndims,
1721 x_shape,
1722 x_strides,
1723 w_batch_ndims,
1724 w_shape,
1725 w_strides,
1726 s_strides,
1727 b_strides,
1728 tid);
1729 qmv_fast_impl<T, group_size, bits>(
1730 w,
1731 scales,
1732 biases,
1733 x,
1734 y,
1735 in_vec_size,
1736 out_vec_size,
1737 tid,
1738 simd_gid,
1739 simd_lid);
1740}
1741
1742template <typename T, int group_size, int bits>
1743[[kernel]] void bs_qmv(
1744 const device uint32_t* w [[buffer(0)]],
1745 const device T* scales [[buffer(1)]],
1746 const device T* biases [[buffer(2)]],
1747 const device T* x [[buffer(3)]],
1748 device T* y [[buffer(4)]],
1749 const constant int& in_vec_size [[buffer(5)]],
1750 const constant int& out_vec_size [[buffer(6)]],
1751 const constant int& x_batch_ndims [[buffer(7)]],
1752 const constant int* x_shape [[buffer(8)]],
1753 const constant size_t* x_strides [[buffer(9)]],
1754 const constant int& w_batch_ndims [[buffer(10)]],
1755 const constant int* w_shape [[buffer(11)]],
1756 const constant size_t* w_strides [[buffer(12)]],
1757 const constant size_t* s_strides [[buffer(13)]],
1758 const constant size_t* b_strides [[buffer(14)]],
1759 const constant int& batch_ndims [[buffer(15)]],
1760 const constant int* batch_shape [[buffer(16)]],
1761 const device uint32_t* lhs_indices [[buffer(17)]],
1762 const device uint32_t* rhs_indices [[buffer(18)]],
1763 const constant size_t* lhs_strides [[buffer(19)]],
1764 const constant size_t* rhs_strides [[buffer(20)]],
1765 uint3 tid [[threadgroup_position_in_grid]],
1766 uint simd_gid [[simdgroup_index_in_threadgroup]],
1767 uint simd_lid [[thread_index_in_simdgroup]]) {
1768 adjust_matrix_offsets<T>(
1769 x,
1770 w,
1771 scales,
1772 biases,
1773 lhs_indices,
1774 rhs_indices,
1775 y,
1776 out_vec_size,
1777 batch_ndims,
1778 batch_shape,
1779 lhs_strides,
1780 rhs_strides,
1781 x_batch_ndims,
1782 x_shape,
1783 x_strides,
1784 w_batch_ndims,
1785 w_shape,
1786 w_strides,
1787 s_strides,
1788 b_strides,
1789 tid);
1790 qmv_impl<T, group_size, bits>(
1791 w,
1792 scales,
1793 biases,
1794 x,
1795 y,
1796 in_vec_size,
1797 out_vec_size,
1798 tid,
1799 simd_gid,
1800 simd_lid);
1801}
1802
1803template <typename T, int group_size, int bits>
1804[[kernel]] void bs_qvm(
1805 const device uint32_t* w [[buffer(0)]],
1806 const device T* scales [[buffer(1)]],
1807 const device T* biases [[buffer(2)]],
1808 const device T* x [[buffer(3)]],
1809 device T* y [[buffer(4)]],
1810 const constant int& in_vec_size [[buffer(5)]],
1811 const constant int& out_vec_size [[buffer(6)]],
1812 const constant int& x_batch_ndims [[buffer(7)]],
1813 const constant int* x_shape [[buffer(8)]],
1814 const constant size_t* x_strides [[buffer(9)]],
1815 const constant int& w_batch_ndims [[buffer(10)]],
1816 const constant int* w_shape [[buffer(11)]],
1817 const constant size_t* w_strides [[buffer(12)]],
1818 const constant size_t* s_strides [[buffer(13)]],
1819 const constant size_t* b_strides [[buffer(14)]],
1820 const constant int& batch_ndims [[buffer(15)]],
1821 const constant int* batch_shape [[buffer(16)]],
1822 const device uint32_t* lhs_indices [[buffer(17)]],
1823 const device uint32_t* rhs_indices [[buffer(18)]],
1824 const constant size_t* lhs_strides [[buffer(19)]],
1825 const constant size_t* rhs_strides [[buffer(20)]],
1826 uint3 tid [[threadgroup_position_in_grid]],
1827 uint simd_gid [[simdgroup_index_in_threadgroup]],
1828 uint simd_lid [[thread_index_in_simdgroup]]) {
1829 adjust_matrix_offsets<T>(
1830 x,
1831 w,
1832 scales,
1833 biases,
1834 lhs_indices,
1835 rhs_indices,
1836 y,
1837 out_vec_size,
1838 batch_ndims,
1839 batch_shape,
1840 lhs_strides,
1841 rhs_strides,
1842 x_batch_ndims,
1843 x_shape,
1844 x_strides,
1845 w_batch_ndims,
1846 w_shape,
1847 w_strides,
1848 s_strides,
1849 b_strides,
1850 tid);
1851 qvm_impl<T, group_size, bits>(
1852 w,
1853 scales,
1854 biases,
1855 x,
1856 y,
1857 in_vec_size,
1858 out_vec_size,
1859 tid,
1860 simd_gid,
1861 simd_lid);
1862}
1863
1864template <
1865 typename T,
1866 const int group_size,
1867 const int bits,
1868 const bool aligned_N,
1869 const int BM = 32,
1870 const int BK = 32,
1871 const int BN = 32>
1872[[kernel]] void bs_qmm_t(
1873 const device uint32_t* w [[buffer(0)]],
1874 const device T* scales [[buffer(1)]],
1875 const device T* biases [[buffer(2)]],
1876 const device T* x [[buffer(3)]],
1877 device T* y [[buffer(4)]],
1878 const constant int& K [[buffer(5)]],
1879 const constant int& N [[buffer(6)]],
1880 const constant int& M [[buffer(7)]],
1881 const constant int& x_batch_ndims [[buffer(8)]],
1882 const constant int* x_shape [[buffer(9)]],
1883 const constant size_t* x_strides [[buffer(10)]],
1884 const constant int& w_batch_ndims [[buffer(11)]],
1885 const constant int* w_shape [[buffer(12)]],
1886 const constant size_t* w_strides [[buffer(13)]],
1887 const constant size_t* s_strides [[buffer(14)]],
1888 const constant size_t* b_strides [[buffer(15)]],
1889 const constant int& batch_ndims [[buffer(16)]],
1890 const constant int* batch_shape [[buffer(17)]],
1891 const device uint32_t* lhs_indices [[buffer(18)]],
1892 const device uint32_t* rhs_indices [[buffer(19)]],
1893 const constant size_t* lhs_strides [[buffer(20)]],
1894 const constant size_t* rhs_strides [[buffer(21)]],
1895 uint3 tid [[threadgroup_position_in_grid]],
1896 uint lid [[thread_index_in_threadgroup]],
1897 uint simd_gid [[simdgroup_index_in_threadgroup]],
1898 uint simd_lid [[thread_index_in_simdgroup]]) {
1899 (void)lid;
1900
1901 constexpr int BK_padded = (BK + 16 / sizeof(T));
1902
1903 threadgroup T Xs[BM * BK_padded];
1904 threadgroup T Ws[BN * BK_padded];
1905
1906 adjust_matrix_offsets<T>(
1907 x,
1908 w,
1909 scales,
1910 biases,
1911 lhs_indices,
1912 rhs_indices,
1913 y,
1914 M * N,
1915 batch_ndims,
1916 batch_shape,
1917 lhs_strides,
1918 rhs_strides,
1919 x_batch_ndims,
1920 x_shape,
1921 x_strides,
1922 w_batch_ndims,
1923 w_shape,
1924 w_strides,
1925 s_strides,
1926 b_strides,
1927 tid);
1928 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1929 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1930}
1931
1932template <
1933 typename T,
1934 const int group_size,
1935 const int bits,
1936 const int BM = 32,
1937 const int BK = 32,
1938 const int BN = 32>
1939[[kernel]] void bs_qmm_n(
1940 const device uint32_t* w [[buffer(0)]],
1941 const device T* scales [[buffer(1)]],
1942 const device T* biases [[buffer(2)]],
1943 const device T* x [[buffer(3)]],
1944 device T* y [[buffer(4)]],
1945 const constant int& K [[buffer(5)]],
1946 const constant int& N [[buffer(6)]],
1947 const constant int& M [[buffer(7)]],
1948 const constant int& x_batch_ndims [[buffer(8)]],
1949 const constant int* x_shape [[buffer(9)]],
1950 const constant size_t* x_strides [[buffer(10)]],
1951 const constant int& w_batch_ndims [[buffer(11)]],
1952 const constant int* w_shape [[buffer(12)]],
1953 const constant size_t* w_strides [[buffer(13)]],
1954 const constant size_t* s_strides [[buffer(14)]],
1955 const constant size_t* b_strides [[buffer(15)]],
1956 const constant int& batch_ndims [[buffer(16)]],
1957 const constant int* batch_shape [[buffer(17)]],
1958 const device uint32_t* lhs_indices [[buffer(18)]],
1959 const device uint32_t* rhs_indices [[buffer(19)]],
1960 const constant size_t* lhs_strides [[buffer(20)]],
1961 const constant size_t* rhs_strides [[buffer(21)]],
1962 uint3 tid [[threadgroup_position_in_grid]],
1963 uint lid [[thread_index_in_threadgroup]],
1964 uint simd_gid [[simdgroup_index_in_threadgroup]],
1965 uint simd_lid [[thread_index_in_simdgroup]]) {
1966 (void)lid;
1967
1968 constexpr int BK_padded = (BK + 16 / sizeof(T));
1969 constexpr int BN_padded = (BN + 16 / sizeof(T));
1970
1971 threadgroup T Xs[BM * BK_padded];
1972 threadgroup T Ws[BK * BN_padded];
1973
1974 adjust_matrix_offsets<T>(
1975 x,
1976 w,
1977 scales,
1978 biases,
1979 lhs_indices,
1980 rhs_indices,
1981 y,
1982 M * N,
1983 batch_ndims,
1984 batch_shape,
1985 lhs_strides,
1986 rhs_strides,
1987 x_batch_ndims,
1988 x_shape,
1989 x_strides,
1990 w_batch_ndims,
1991 w_shape,
1992 w_strides,
1993 s_strides,
1994 b_strides,
1995 tid);
1996 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1997 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1998}
1999
2000template <typename T, const int group_size, const int bits>
2001[[kernel]] void affine_quantize(
2002 const device T* w [[buffer(0)]],
2003 device uint8_t* out [[buffer(1)]],
2004 device T* scales [[buffer(2)]],
2005 device T* biases [[buffer(3)]],
2006 uint2 index [[thread_position_in_grid]],
2007 uint2 grid_dim [[threads_per_grid]]) {
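  // One 32-thread simdgroup covers a single quantization group of group_size
  // weights; each thread handles values_per_reduce = group_size / 32 of them.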
2008 constexpr T eps = T(1e-7);
2009 constexpr int simd_size = 32;
2010 constexpr T n_bins = (1 << bits) - 1;
2011 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2012 constexpr int values_per_reduce = group_size / simd_size;
2013 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
2014 constexpr int writes_per_pack =
2015 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
2016 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2017 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2018
2019 static_assert(
2020 group_size % simd_size == 0,
2021 "Group size must be divisible by simd size.");
2022
2023 size_t offset = index.x + grid_dim.x * size_t(index.y);
2024 size_t in_index = offset * values_per_reduce;
2025 size_t out_index = power_of_2_bits
2026 ? offset * writes_per_pack
2027 : offset * bytes_per_pack / writes_per_reduce;
2028
2029 T w_thread[values_per_reduce];
2030 T w_min = Limits<T>::max;
2031 T w_max = 0;
2032
2033#pragma clang loop unroll(full)
2034 for (int i = 0; i < values_per_reduce; i++) {
2035 T val = w[in_index + i];
2036 w_thread[i] = val;
2037 w_min = min(w_min, val);
2038 w_max = max(w_max, val);
2039 }
2040
2041 w_min = simd_min(w_min);
2042 w_max = simd_max(w_max);
2043
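  // Affine code: w ~ q * scale + bias with q in [0, n_bins]. Take the range
  // edge with the larger magnitude; if it rounds to a nonzero grid index q0,
  // snap the scale to edge / q0 so that the edge (and an exact zero inside the
  // range) is represented exactly, and store the edge as the bias.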
2044 T scale = max((w_max - w_min) / n_bins, eps);
2045 bool side = abs(w_min) > abs(w_max);
2046 scale = side ? scale : -scale;
2047 T edge = side ? w_min : w_max;
2048 T q0 = round(edge / scale);
2049 bool at_zero = q0 == 0.0f;
2050 scale = at_zero ? scale : edge / q0;
2051 T bias = at_zero ? T(0) : edge;
2052
2053 // Write out the scales and biases
2054 size_t gindex = in_index / group_size;
2055 if (in_index % group_size == 0) {
2056 scales[gindex] = scale;
2057 biases[gindex] = bias;
2058 }
2059
2060 // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
2061 uint32_t output = 0;
2062
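  // Pack this thread's codes into `output`. When one packed byte spans several
  // threads (writes_per_reduce > 1), their values are gathered with
  // simd_shuffle_down so a single thread writes the combined result below.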
2063#pragma clang loop unroll(full)
2064 for (int i = 0; i < values_per_reduce; i++) {
2065 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
2066 if (bits == 8) {
2067 output = val;
2068 } else {
2069 output += val << (bits * (i % packs_per_int));
2070 }
2071
2072 if (packs_per_int < values_per_reduce &&
2073 i % packs_per_int == packs_per_int - 1) {
2074 out[out_index + i / packs_per_int] = output;
2075 output = 0;
2076 } else {
2077#pragma clang loop unroll(full)
2078 for (int j = 1; j < writes_per_reduce; j++) {
2079 uint8_t sval = simd_shuffle_down(val, j);
2080 output += sval << (bits * (j * values_per_reduce + i));
2081 }
2082 }
2083 }
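  // 3-bit and 6-bit groups pack into 3 bytes, so the accumulated word is
  // written out one byte at a time.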
2084 if (bits == 3 || bits == 6) {
2085 if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
2086 out[out_index] = output & 0xff;
2087 out[out_index + 1] = (output & 0xff00) >> 8;
2088 out[out_index + 2] = (output & 0xff0000) >> 16;
2089 }
2090 } else {
2091 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
2092 out[out_index / writes_per_reduce] = output;
2093 }
2094 }
2095}
2096
2097template <typename T, const int group_size, const int bits>
2098[[kernel]] void affine_dequantize(
2099 const device uint8_t* w [[buffer(0)]],
2100 const device T* scales [[buffer(1)]],
2101 const device T* biases [[buffer(2)]],
2102 device T* out [[buffer(3)]],
2103 uint2 index [[thread_position_in_grid]],
2104 uint2 grid_dim [[threads_per_grid]]) {
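  // Each thread reconstructs one pack: 3 bytes -> 8 values for 3-bit,
  // 3 bytes -> 4 values for 6-bit, otherwise a single byte -> 8 / bits values.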
2105 constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
2106 constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
2107 constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
2108
2109 size_t offset = index.x + grid_dim.x * size_t(index.y);
2110 size_t oindex = offset * packs_per_int;
2111 size_t gindex = oindex / group_size;
2112 T scale = scales[gindex];
2113 T bias = biases[gindex];
2114
2115 out += oindex;
2116
2117 if (bits == 3) {
2118 w += offset * bytes_per_pack;
2119 out[0] = (w[0] & 0x7) * scale + bias;
2120 out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
2121 out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
2122 out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
2123 out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
2124 out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
2125 out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
2126 out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
2127
2128 } else if (bits == 6) {
2129 w += offset * bytes_per_pack;
2130 out[0] = (w[0] & 0x3f) * scale + bias;
2131 out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
2132 out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
2133 out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
2134 } else {
2135 uint val = w[offset];
2136#pragma clang loop unroll(full)
2137 for (int i = 0; i < packs_per_int; i++) {
2138 uint8_t d;
2139 if (bits == 2) {
2140 d = (val >> (bits * i)) & 0x03;
2141 } else if (bits == 4) {
2142 d = (val >> (bits * i)) & 0x0f;
2143 } else if (bits == 8) {
2144 d = val;
2145 }
2146 out[i] = scale * d + bias;
2147 }
2148 }
2149}
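As a rough illustration of the affine scheme implemented by the affine_quantize and affine_dequantize kernels above, here is a minimal host-side C++ sketch (not MLX API; quantize_group and dequantize_group are hypothetical names, the byte packing and simdgroup reduction are omitted, and only the per-group arithmetic is shown). One group of weights is coded as q in [0, 2^bits - 1] and reconstructed as q * scale + bias:

// Minimal host-side sketch of the per-group affine quantization round trip.
// Illustration only: quantize_group / dequantize_group are hypothetical names,
// not part of MLX; packing codes into bytes is left out.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void quantize_group(const std::vector<float>& w, int bits,
                    std::vector<uint8_t>& q, float& scale, float& bias) {
  float n_bins = float((1 << bits) - 1);
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  // Pick the range edge with the larger magnitude and, when possible, snap
  // the scale so the edge (and an exact zero inside the range) is on the grid.
  bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = std::round(edge / scale);
  if (q0 != 0.0f) {
    scale = edge / q0;
    bias = edge;
  } else {
    bias = 0.0f;
  }
  q.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = uint8_t(std::min(std::round((w[i] - bias) / scale), n_bins));
  }
}

void dequantize_group(const std::vector<uint8_t>& q, float scale, float bias,
                      std::vector<float>& out) {
  out.resize(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    out[i] = float(q[i]) * scale + bias;  // same reconstruction as the kernels
  }
}

int main() {
  std::vector<float> w = {-0.9f, -0.3f, 0.0f, 0.4f, 0.8f, 1.2f};
  std::vector<uint8_t> q;
  std::vector<float> wq;
  float scale = 0.0f, bias = 0.0f;
  quantize_group(w, /*bits=*/4, q, scale, bias);
  dequantize_group(q, scale, bias, wq);
  for (size_t i = 0; i < w.size(); ++i) {
    std::printf("%+.3f -> %2u -> %+.3f\n", w[i], unsigned(q[i]), wq[i]);
  }
  return 0;
}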
static constant constexpr const uint8_t simd_size
Definition ops.h:22
METAL_FUNC ulong2 elem_to_loc_broadcast(uint elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:7
METAL_FUNC IdxT elem_to_loc(uint elem, constant const int *shape, constant const StrideT *strides, int ndim)
Definition utils.h:93
Definition bf16_math.h:226
METAL_FUNC bfloat16_t simd_max(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t round(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t abs(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t simd_min(bfloat16_t data)
Definition bf16_math.h:378
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:377
#define MLX_MTL_CONST
Definition quantized.h:8
U qdot_safe(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum, int N)
Definition quantized.h:225
METAL_FUNC void qmm_n_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1081
METAL_FUNC void qvm_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:843
void bs_qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1939
void qmm_n(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1629
void affine_quantize(const device T *w, device uint8_t *out, device T *scales, device T *biases, uint2 index, uint2 grid_dim)
Definition quantized.h:2001
void bs_qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1682
void affine_dequantize(const device uint8_t *w, const device T *scales, const device T *biases, device T *out, uint2 index, uint2 grid_dim)
Definition quantized.h:2098
static constant constexpr const int SIMD_SIZE
Definition quantized.h:10
void qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1406
void bs_qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1804
void qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1355
void qmv_quad(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:1304
static constant constexpr const int QUAD_SIZE
Definition quantized.h:11
U load_vector(const device T *x, thread U *x_thread)
Definition quantized.h:14
METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:688
U load_vector_safe(const device T *x, thread U *x_thread, int N)
Definition quantized.h:77
void bs_qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1872
U qdot(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
Definition quantized.h:145
void qvm_split_k(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1508
METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:620
void qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1571
METAL_FUNC void adjust_matrix_offsets(const device T *&x, const device uint32_t *&w, const device T *&scales, const device T *&biases, device T *&y, int output_stride, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid)
Definition quantized.h:1211
void bs_qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1743
METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:563
void qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1457
void qouter(const thread uint8_t *w, U x, U scale, U bias, thread U *result)
Definition quantized.h:307
void dequantize(const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
Definition quantized.h:372
METAL_FUNC void qmm_t_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:956
Definition utils.h:23
Definition quantized.h:443
const int group_stride
Definition quantized.h:464
static constant constexpr const short BCOLS_PACKED
Definition quantized.h:456
const device T * biases
Definition quantized.h:473
short group_step_cnt
Definition quantized.h:463
static constant constexpr const short group_steps
Definition quantized.h:459
const short thread_idx
Definition quantized.h:466
QuantizedBlockLoader(const device uint8_t *src_, const device T *scales_, const device T *biases_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition quantized.h:475
const device T * scales
Definition quantized.h:472
static constant constexpr const short n_reads
Definition quantized.h:457
void next()
Definition quantized.h:541
void load_safe(short2 src_tile_dim) const
Definition quantized.h:511
const int src_ld
Definition quantized.h:461
const short bi
Definition quantized.h:467
void load_unsafe() const
Definition quantized.h:498
static constant constexpr const short pack_factor
Definition quantized.h:454
threadgroup T * dst
Definition quantized.h:470
const device uint8_t * src
Definition quantized.h:471
const int tile_stride
Definition quantized.h:462
static constant constexpr const short bytes_per_pack
Definition quantized.h:455
const short bj
Definition quantized.h:468
Definition loader.h:25