MLX
quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10 MLX_MTL_CONST int SIMD_SIZE = 32;
11 MLX_MTL_CONST int QUAD_SIZE = 4;
12
13template <typename T, typename U, int values_per_thread, int bits>
14inline U load_vector(const device T* x, thread U* x_thread) {
15 static_assert(
16 bits == 2 || bits == 4 || bits == 8,
17 "Template undefined for bits not in {2, 4, 8}");
18
19 U sum = 0;
20
21 if (bits == 2) {
22 for (int i = 0; i < values_per_thread; i += 4) {
23 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
24 x_thread[i] = x[i];
25 x_thread[i + 1] = x[i + 1] / 4.0f;
26 x_thread[i + 2] = x[i + 2] / 16.0f;
27 x_thread[i + 3] = x[i + 3] / 64.0f;
28 }
29 }
30
31 else if (bits == 4) {
32 for (int i = 0; i < values_per_thread; i += 4) {
33 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
34 x_thread[i] = x[i];
35 x_thread[i + 1] = x[i + 1] / 16.0f;
36 x_thread[i + 2] = x[i + 2] / 256.0f;
37 x_thread[i + 3] = x[i + 3] / 4096.0f;
38 }
39 }
40
41 else if (bits == 8) {
42 for (int i = 0; i < values_per_thread; i++) {
43 sum += x[i];
44 x_thread[i] = x[i];
45 }
46 }
47
48 return sum;
49}
50
51template <typename T, typename U, int values_per_thread, int bits>
52inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
53 static_assert(
54 bits == 2 || bits == 4 || bits == 8,
55 "Template undefined for bits not in {2, 4, 8}");
56
57 U sum = 0;
58
59 if (bits == 2) {
60 for (int i = 0; i < N; i += 4) {
61 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
62 x_thread[i] = x[i];
63 x_thread[i + 1] = x[i + 1] / 4.0f;
64 x_thread[i + 2] = x[i + 2] / 16.0f;
65 x_thread[i + 3] = x[i + 3] / 64.0f;
66 }
67 for (int i = N; i < values_per_thread; i++) {
68 x_thread[i] = 0;
69 }
70 }
71
72 else if (bits == 4) {
73 for (int i = 0; i < N; i += 4) {
74 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
75 x_thread[i] = x[i];
76 x_thread[i + 1] = x[i + 1] / 16.0f;
77 x_thread[i + 2] = x[i + 2] / 256.0f;
78 x_thread[i + 3] = x[i + 3] / 4096.0f;
79 }
80 for (int i = N; i < values_per_thread; i++) {
81 x_thread[i] = 0;
82 }
83 }
84
85 else if (bits == 8) {
86 for (int i = 0; i < N; i++) {
87 sum += x[i];
88 x_thread[i] = x[i];
89 }
90 for (int i = N; i < values_per_thread; i++) {
91 x_thread[i] = 0;
92 }
93 }
94
95 return sum;
96}
97
98template <typename U, int values_per_thread, int bits>
99inline U qdot(
100 const device uint8_t* w,
101 const thread U* x_thread,
102 U scale,
103 U bias,
104 U sum) {
105 static_assert(
106 bits == 2 || bits == 4 || bits == 8,
107 "Template undefined for bits not in {2, 4, 8}");
108
109 U accum = 0;
110
111 if (bits == 2) {
112 for (int i = 0; i < (values_per_thread / 4); i++) {
113 accum +=
114 (x_thread[4 * i] * (w[i] & 0x03) +
115 x_thread[4 * i + 1] * (w[i] & 0x0c) +
116 x_thread[4 * i + 2] * (w[i] & 0x30) +
117 x_thread[4 * i + 3] * (w[i] & 0xc0));
118 }
119 }
120
121 else if (bits == 4) {
122 const device uint16_t* ws = (const device uint16_t*)w;
123 for (int i = 0; i < (values_per_thread / 4); i++) {
124 accum +=
125 (x_thread[4 * i] * (ws[i] & 0x000f) +
126 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
127 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
128 x_thread[4 * i + 3] * (ws[i] & 0xf000));
129 }
130 }
131
132 else if (bits == 8) {
133 for (int i = 0; i < values_per_thread; i++) {
134 accum += x_thread[i] * w[i];
135 }
136 }
137
138 return scale * accum + sum * bias;
139}
140
141template <typename U, int values_per_thread, int bits>
142inline U qdot_safe(
143 const device uint8_t* w,
144 const thread U* x_thread,
145 U scale,
146 U bias,
147 U sum,
148 int N) {
149 static_assert(
150 bits == 2 || bits == 4 || bits == 8,
151 "Template undefined for bits not in {2, 4, 8}");
152
153 U accum = 0;
154
155 if (bits == 2) {
156 for (int i = 0; i < (N / 4); i++) {
157 accum +=
158 (x_thread[4 * i] * (w[i] & 0x03) +
159 x_thread[4 * i + 1] * (w[i] & 0x0c) +
160 x_thread[4 * i + 2] * (w[i] & 0x30) +
161 x_thread[4 * i + 3] * (w[i] & 0xc0));
162 }
163 }
164
165 else if (bits == 4) {
166 const device uint16_t* ws = (const device uint16_t*)w;
167 for (int i = 0; i < (N / 4); i++) {
168 accum +=
169 (x_thread[4 * i] * (ws[i] & 0x000f) +
170 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
171 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
172 x_thread[4 * i + 3] * (ws[i] & 0xf000));
173 }
174 }
175
176 else if (bits == 8) {
177 for (int i = 0; i < N; i++) {
178 accum += x_thread[i] * w[i];
179 }
180 }
181
182 return scale * accum + sum * bias;
183}
184
185template <typename U, int values_per_thread, int bits>
186inline void
187qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
188 static_assert(
189 bits == 2 || bits == 4 || bits == 8,
190 "Template undefined for bits not in {2, 4, 8}");
191
192 if (bits == 2) {
193 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
194 for (int i = 0; i < (values_per_thread / 4); i++) {
195 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
196 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
197 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
198 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
199 }
200 }
201
202 else if (bits == 4) {
203 U s[2] = {scale, scale / 16.0f};
204 for (int i = 0; i < (values_per_thread / 2); i++) {
205 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
206 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
207 }
208 }
209
210 else if (bits == 8) {
211 for (int i = 0; i < values_per_thread; i++) {
212 result[i] += x * (scale * w[i] + bias);
213 }
214 }
215}
216
217template <typename U, int N, int bits>
218inline void
219dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
220 static_assert(
221 bits == 2 || bits == 4 || bits == 8,
222 "Template undefined for bits not in {2, 4, 8}");
223
224 if (bits == 2) {
225 U s[4] = {
226 scale,
227 scale / static_cast<U>(4.0f),
228 scale / static_cast<U>(16.0f),
229 scale / static_cast<U>(64.0f)};
230 for (int i = 0; i < (N / 4); i++) {
231 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
232 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
233 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
234 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
235 }
236 }
237
238 else if (bits == 4) {
239 U s[2] = {scale, scale / static_cast<U>(16.0f)};
240 for (int i = 0; i < (N / 2); i++) {
241 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
242 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
243 }
244 }
245
246 else if (bits == 8) {
247 for (int i = 0; i < N; i++) {
248 w_local[i] = scale * w[i] + bias;
249 }
250 }
251}
252
253template <
254 typename T,
255 short BROWS,
256 short BCOLS,
257 short dst_ld,
258 short reduction_dim,
259 short tgp_size,
260 short group_size,
261 short bits>
262 struct QuantizedBlockLoader {
263 static_assert(
264 BCOLS <= group_size,
265 "The group size should be larger than the columns");
266 static_assert(
267 group_size % BCOLS == 0,
268 "The group size should be divisible by the columns");
269 static_assert(
270 bits == 2 || bits == 4 || bits == 8,
271 "Template undefined for bits not in {2, 4, 8}");
272
273 MLX_MTL_CONST short pack_factor = 32 / bits;
274 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
275 MLX_MTL_CONST short n_reads =
276 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
277 MLX_MTL_CONST short group_steps = group_size / BCOLS;
278
279 const int src_ld;
280 const int tile_stride;
281 short group_step_cnt;
282 const int group_stride;
283
284 const short thread_idx;
285 const short bi;
286 const short bj;
287
288 threadgroup T* dst;
289 const device uint32_t* src;
290 const device T* scales;
291 const device T* biases;
292
293 QuantizedBlockLoader(
294 const device uint32_t* src_,
295 const device T* scales_,
296 const device T* biases_,
297 const int src_ld_,
298 threadgroup T* dst_,
299 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
300 ushort simd_lane_id [[thread_index_in_simdgroup]])
301 : src_ld(src_ld_),
302 tile_stride(
303 reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
304 group_step_cnt(0),
305 group_stride(BROWS * src_ld / group_size),
306 thread_idx(simd_group_id * 32 + simd_lane_id),
307 bi(n_reads * thread_idx / BCOLS_PACKED),
308 bj((n_reads * thread_idx) % BCOLS_PACKED),
309 dst(dst_ + bi * dst_ld + bj * pack_factor),
310 src(src_ + bi * src_ld / pack_factor + bj),
311 scales(scales_ + bi * src_ld / group_size),
312 biases(biases_ + bi * src_ld / group_size) {}
313
314 void load_unsafe() const {
315 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
316 return;
317 }
318
319 T scale = *scales;
320 T bias = *biases;
321 for (int i = 0; i < n_reads; i++) {
322 dequantize<T, pack_factor, bits>(
323 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
324 }
325 }
326
327 void load_safe(short2 src_tile_dim) const {
328 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
329 return;
330 }
331
332 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
333 for (int i = 0; i < n_reads * pack_factor; i++) {
334 dst[i] = T(0);
335 }
336 return;
337 }
338
339 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
340 for (int i = 0; i < n_reads * pack_factor; i++) {
341 dst[i] = T(0);
342 }
343 return;
344 }
345
346 T scale = *scales;
347 T bias = *biases;
348 for (int i = 0; i < n_reads; i++) {
349 dequantize<T, pack_factor, bits>(
350 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
351 }
352 }
353
354 void next() {
355 src += tile_stride;
356 if (reduction_dim == 1) {
357 if (group_steps > 1) {
358 group_step_cnt++;
359 if (group_step_cnt == group_steps) {
360 group_step_cnt = 0;
361 scales++;
362 biases++;
363 }
364 } else {
365 scales++;
366 biases++;
367 }
368 } else {
369 scales += group_stride;
370 biases += group_stride;
371 }
372 }
373};
374
375template <typename T, int group_size, int bits, int D>
376METAL_FUNC void qmv_quad_impl(
377 const device uint32_t* w,
378 const device T* scales,
379 const device T* biases,
380 const device T* x,
381 device T* y,
382 constant int& in_vec_size,
383 const constant int& out_vec_size,
384 uint3 tid [[threadgroup_position_in_grid]],
385 uint quad_gid [[quadgroup_index_in_threadgroup]],
386 uint quad_lid [[thread_index_in_quadgroup]]) {
387 constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
388 constexpr int pack_factor = 32 / bits;
389 constexpr int values_per_thread = D / QUAD_SIZE;
390 constexpr int packs_per_thread = values_per_thread / pack_factor;
391 constexpr int scale_step_per_thread = group_size / values_per_thread;
392 constexpr int results_per_quadgroup = 8;
393
394 typedef float U;
395
396 thread U x_thread[values_per_thread];
397 thread U result[results_per_quadgroup] = {0};
398
399 // Adjust positions
400 const int in_vec_size_w = in_vec_size / pack_factor;
401 const int in_vec_size_g = in_vec_size / group_size;
402 const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
403
404 w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
405 scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
406 biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
407 x += tid.y * in_vec_size + quad_lid * values_per_thread;
408 y += tid.y * out_vec_size + out_row;
409
410 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
411
412 for (int row = 0; row < results_per_quadgroup; row++) {
413 const device uint8_t* wl =
414 (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
415 const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
416 const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
417
418 U s = sl[0];
419 U b = bl[0];
420 if (row * quads_per_simd + out_row < out_vec_size) {
421 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
422 }
423 }
424
425 for (int row = 0; row < results_per_quadgroup; row++) {
426 result[row] = quad_sum(result[row]);
427 if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
428 y[row * quads_per_simd] = static_cast<T>(result[row]);
429 }
430 }
431}
432
433template <typename T, int group_size, int bits>
434METAL_FUNC void qmv_fast_impl(
435 const device uint32_t* w,
436 const device T* scales,
437 const device T* biases,
438 const device T* x,
439 device T* y,
440 const constant int& in_vec_size,
441 const constant int& out_vec_size,
442 uint3 tid [[threadgroup_position_in_grid]],
443 uint simd_gid [[simdgroup_index_in_threadgroup]],
444 uint simd_lid [[thread_index_in_simdgroup]]) {
445 constexpr int packs_per_thread = bits > 2 ? 2 : 1;
446 constexpr int num_simdgroups = 2;
447 constexpr int results_per_simdgroup = 4;
448 constexpr int pack_factor = 32 / bits;
449 constexpr int values_per_thread = pack_factor * packs_per_thread;
450 constexpr int block_size = values_per_thread * SIMD_SIZE;
451 constexpr int scale_step_per_thread = group_size / values_per_thread;
452
453 typedef float U;
454
455 thread U x_thread[values_per_thread];
456 thread U result[results_per_simdgroup] = {0};
457
458 // Adjust positions
459 const int in_vec_size_w = in_vec_size / pack_factor;
460 const int in_vec_size_g = in_vec_size / group_size;
461 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
462 simd_gid * results_per_simdgroup;
463 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
464 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
465 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
466 x += tid.y * in_vec_size + simd_lid * values_per_thread;
467 y += tid.y * out_vec_size + out_row;
468
469 for (int k = 0; k < in_vec_size; k += block_size) {
470 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
471
472 for (int row = 0; row < results_per_simdgroup; row++) {
473 const device uint8_t* wl =
474 (const device uint8_t*)(w + row * in_vec_size_w);
475 const device T* sl = scales + row * in_vec_size_g;
476 const device T* bl = biases + row * in_vec_size_g;
477
478 U s = sl[0];
479 U b = bl[0];
480 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
481 }
482
483 w += block_size / pack_factor;
484 scales += block_size / group_size;
485 biases += block_size / group_size;
486 x += block_size;
487 }
488
489 for (int row = 0; row < results_per_simdgroup; row++) {
490 result[row] = simd_sum(result[row]);
491 if (simd_lid == 0) {
492 y[row] = static_cast<T>(result[row]);
493 }
494 }
495}
496
497template <typename T, int group_size, int bits>
498METAL_FUNC void qmv_impl(
499 const device uint32_t* w,
500 const device T* scales,
501 const device T* biases,
502 const device T* x,
503 device T* y,
504 const constant int& in_vec_size,
505 const constant int& out_vec_size,
506 uint3 tid [[threadgroup_position_in_grid]],
507 uint simd_gid [[simdgroup_index_in_threadgroup]],
508 uint simd_lid [[thread_index_in_simdgroup]]) {
509 constexpr int num_simdgroups = 2;
510 constexpr int results_per_simdgroup = 4;
511 constexpr int packs_per_thread = 1;
512 constexpr int pack_factor = 32 / bits;
513 constexpr int values_per_thread = pack_factor * packs_per_thread;
514 constexpr int block_size = values_per_thread * SIMD_SIZE;
515 constexpr int scale_step_per_thread = group_size / values_per_thread;
516
517 typedef float U;
518
519 thread U x_thread[values_per_thread];
520 thread U result[results_per_simdgroup] = {0};
521
522 // Adjust positions
523 const int in_vec_size_w = in_vec_size / pack_factor;
524 const int in_vec_size_g = in_vec_size / group_size;
525 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
526 simd_gid * results_per_simdgroup;
527 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
528
529 if (out_row >= out_vec_size) {
530 return;
531 }
532
533 // In this case we need to properly guard all our reads because there isn't
534 // even 1 tile in the matrix
535 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
536 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
537 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
538 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
539 x += tid.y * in_vec_size + simd_lid * values_per_thread;
540 y += tid.y * out_vec_size + out_row;
541
542 int k = 0;
543 for (; k < in_vec_size - block_size; k += block_size) {
544 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
545
546 for (int row = 0; out_row + row < out_vec_size; row++) {
547 const device uint8_t* wl =
548 (const device uint8_t*)(w + row * in_vec_size_w);
549 const device T* sl = scales + row * in_vec_size_g;
550 const device T* bl = biases + row * in_vec_size_g;
551
552 U s = sl[0];
553 U b = bl[0];
554 result[row] +=
555 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
556 }
557
558 w += block_size / pack_factor;
559 scales += block_size / group_size;
560 biases += block_size / group_size;
561 x += block_size;
562 }
563 const int remaining = clamp(
564 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
565 0,
566 values_per_thread);
567 U sum =
568 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
569
570 for (int row = 0; out_row + row < out_vec_size; row++) {
571 const device uint8_t* wl =
572 (const device uint8_t*)(w + row * in_vec_size_w);
573 const device T* sl = scales + row * in_vec_size_g;
574 const device T* bl = biases + row * in_vec_size_g;
575
576 U s = sl[0];
577 U b = bl[0];
578 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
579 }
580
581 for (int row = 0; out_row + row < out_vec_size; row++) {
582 result[row] = simd_sum(result[row]);
583 if (simd_lid == 0) {
584 y[row] = static_cast<T>(result[row]);
585 }
586 }
587 }
588
589 // In this case the last tile is moved back to redo some output values
590 else {
591 w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
592 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
593 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
594 x += tid.y * in_vec_size + simd_lid * values_per_thread;
595 y += tid.y * out_vec_size + used_out_row;
596
597 int k = 0;
598 for (; k < in_vec_size - block_size; k += block_size) {
599 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
600
601 for (int row = 0; row < results_per_simdgroup; row++) {
602 const device uint8_t* wl =
603 (const device uint8_t*)(w + row * in_vec_size_w);
604 const device T* sl = scales + row * in_vec_size_g;
605 const device T* bl = biases + row * in_vec_size_g;
606
607 U s = sl[0];
608 U b = bl[0];
609 result[row] +=
610 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
611 }
612
613 w += block_size / pack_factor;
614 scales += block_size / group_size;
615 biases += block_size / group_size;
616 x += block_size;
617 }
618 const int remaining = clamp(
619 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
620 0,
621 values_per_thread);
622 U sum =
623 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
624
625 for (int row = 0; row < results_per_simdgroup; row++) {
626 const device uint8_t* wl =
627 (const device uint8_t*)(w + row * in_vec_size_w);
628 const device T* sl = scales + row * in_vec_size_g;
629 const device T* bl = biases + row * in_vec_size_g;
630
631 U s = sl[0];
632 U b = bl[0];
633 result[row] += qdot_safe<U, values_per_thread, bits>(
634 wl, x_thread, s, b, sum, remaining);
635 }
636
637 for (int row = 0; row < results_per_simdgroup; row++) {
638 result[row] = simd_sum(result[row]);
639 if (simd_lid == 0) {
640 y[row] = static_cast<T>(result[row]);
641 }
642 }
643 }
644}
645
646template <typename T, const int group_size, const int bits>
647METAL_FUNC void qvm_impl(
648 const device uint32_t* w,
649 const device T* scales,
650 const device T* biases,
651 const device T* x,
652 device T* y,
653 const int in_vec_size,
654 const int out_vec_size,
655 uint3 tid [[threadgroup_position_in_grid]],
656 uint simd_gid [[simdgroup_index_in_threadgroup]],
657 uint simd_lid [[thread_index_in_simdgroup]]) {
658 constexpr int num_simdgroups = 2;
659 constexpr int pack_factor = 32 / bits;
660 constexpr int tn = 32 / pack_factor;
661 constexpr int blocksize = SIMD_SIZE;
662
663 typedef float U;
664 typedef struct {
665 uint32_t wi[tn];
666 } vec_w;
667
668 thread vec_w w_local;
669 thread U result[tn * pack_factor] = {0};
670 thread U scale = 1;
671 thread U bias = 0;
672 thread U x_local = 0;
673
674 // Adjust positions
675 const int out_vec_size_w = out_vec_size / pack_factor;
676 const int out_vec_size_g = out_vec_size / group_size;
677 int out_col =
678 tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
679 w += out_col / pack_factor + simd_lid * out_vec_size_w;
680 scales += out_col / group_size + simd_lid * out_vec_size_g;
681 biases += out_col / group_size + simd_lid * out_vec_size_g;
682 x += tid.y * in_vec_size + simd_lid;
683 y += tid.y * out_vec_size + out_col;
684
685 if (out_col >= out_vec_size) {
686 return;
687 }
688
689 // Loop over in_vec in blocks of blocksize
690 int remaining = in_vec_size % blocksize;
691 if (remaining == 0) {
692 for (int i = 0; i < in_vec_size; i += blocksize) {
693 x_local = *x;
694 scale = *scales;
695 bias = *biases;
696 w_local = *((device vec_w*)w);
697
698 qouter<U, tn * pack_factor, bits>(
699 (thread uint8_t*)&w_local, x_local, scale, bias, result);
700
701 x += blocksize;
702 scales += blocksize * out_vec_size_g;
703 biases += blocksize * out_vec_size_g;
704 w += blocksize * out_vec_size_w;
705 }
706 } else {
707 for (int i = blocksize; i < in_vec_size; i += blocksize) {
708 x_local = *x;
709 scale = *scales;
710 bias = *biases;
711 w_local = *((device vec_w*)w);
712
713 qouter<U, tn * pack_factor, bits>(
714 (thread uint8_t*)&w_local, x_local, scale, bias, result);
715
716 x += blocksize;
717 scales += blocksize * out_vec_size_g;
718 biases += blocksize * out_vec_size_g;
719 w += blocksize * out_vec_size_w;
720 }
721 if (static_cast<int>(simd_lid) < remaining) {
722 x_local = *x;
723 scale = *scales;
724 bias = *biases;
725 w_local = *((device vec_w*)w);
726 } else {
727 x_local = 0;
728 scale = 0;
729 bias = 0;
730 }
731 qouter<U, tn * pack_factor, bits>(
732 (thread uint8_t*)&w_local, x_local, scale, bias, result);
733 }
734
735// Accumulate in the simdgroup
736#pragma clang loop unroll(full)
737 for (int k = 0; k < tn * pack_factor; k++) {
738 result[k] = simd_sum(result[k]);
739 }
740
741 // Store the result
742 if (simd_lid == 0) {
743#pragma clang loop unroll(full)
744 for (int k = 0; k < tn * pack_factor; k++) {
745 y[k] = static_cast<T>(result[k]);
746 }
747 }
748}
749
750template <
751 typename T,
752 const int group_size,
753 const int bits,
754 const bool aligned_N,
755 const int BM = 32,
756 const int BK = 32,
757 const int BN = 32>
758METAL_FUNC void qmm_t_impl(
759 const device uint32_t* w,
760 const device T* scales,
761 const device T* biases,
762 const device T* x,
763 device T* y,
764 threadgroup T* Xs,
765 threadgroup T* Ws,
766 const constant int& K,
767 const constant int& N,
768 const constant int& M,
769 uint3 tid [[threadgroup_position_in_grid]],
770 uint lid [[thread_index_in_threadgroup]],
771 uint simd_gid [[simdgroup_index_in_threadgroup]],
772 uint simd_lid [[thread_index_in_simdgroup]]) {
773 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
774 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
775
776 (void)lid;
777
778 constexpr int WM = 2;
779 constexpr int WN = 2;
780 constexpr int pack_factor = 32 / bits;
781 constexpr int BK_padded = (BK + 16 / sizeof(T));
782
783 // Instantiate the appropriate BlockMMA and Loader
784 using mma_t = mlx::steel::
785 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
786 using loader_x_t =
787 mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
788 using loader_w_t = QuantizedBlockLoader<
789 T,
790 BN,
791 BK,
792 BK_padded,
793 1,
794 WM * WN * SIMD_SIZE,
795 group_size,
796 bits>;
797
798 // Set the block
799 const int K_w = K / pack_factor;
800 const int K_g = K / group_size;
801 const int y_row = tid.y * BM;
802 const int y_col = tid.x * BN;
803
804 x += y_row * K;
805 w += y_col * K_w;
806 scales += y_col * K_g;
807 biases += y_col * K_g;
808 y += y_row * N + y_col;
809
810 // Make the x loader and mma operation
811 const short num_els = min(BM, M - y_row);
812 const short num_outs = min(BN, N - y_col);
813 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
814 loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
815 mma_t mma_op(simd_gid, simd_lid);
816
817 if (num_els < BM) {
818 if (!aligned_N && num_outs < BN) {
819 for (int k = 0; k < K; k += BK) {
820 threadgroup_barrier(mem_flags::mem_threadgroup);
821 loader_x.load_safe(short2(BK, num_els));
822 loader_w.load_safe(short2(BK, num_outs));
823 threadgroup_barrier(mem_flags::mem_threadgroup);
824 mma_op.mma(Xs, Ws);
825 loader_x.next();
826 loader_w.next();
827 }
828 } else {
829 for (int k = 0; k < K; k += BK) {
830 threadgroup_barrier(mem_flags::mem_threadgroup);
831 loader_x.load_safe(short2(BK, num_els));
832 loader_w.load_unsafe();
833 threadgroup_barrier(mem_flags::mem_threadgroup);
834 mma_op.mma(Xs, Ws);
835 loader_x.next();
836 loader_w.next();
837 }
838 }
839 } else {
840 if (!aligned_N && num_outs < BN) {
841 for (int k = 0; k < K; k += BK) {
842 threadgroup_barrier(mem_flags::mem_threadgroup);
843 loader_x.load_unsafe();
844 loader_w.load_safe(short2(BK, num_outs));
845 threadgroup_barrier(mem_flags::mem_threadgroup);
846 mma_op.mma(Xs, Ws);
847 loader_x.next();
848 loader_w.next();
849 }
850 } else {
851 for (int k = 0; k < K; k += BK) {
852 threadgroup_barrier(mem_flags::mem_threadgroup);
853 loader_x.load_unsafe();
854 loader_w.load_unsafe();
855 threadgroup_barrier(mem_flags::mem_threadgroup);
856 mma_op.mma(Xs, Ws);
857 loader_x.next();
858 loader_w.next();
859 }
860 }
861 }
862
863 // Store results to device memory
864 threadgroup_barrier(mem_flags::mem_threadgroup);
865 if (num_els < BM || num_outs < BN) {
866 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
867 } else {
868 mma_op.store_result(y, N);
869 }
870}
871
872template <
873 typename T,
874 const int group_size,
875 const int bits,
876 const int BM = 32,
877 const int BK = 32,
878 const int BN = 32>
879METAL_FUNC void qmm_n_impl(
880 const device uint32_t* w,
881 const device T* scales,
882 const device T* biases,
883 const device T* x,
884 device T* y,
885 threadgroup T* Xs,
886 threadgroup T* Ws,
887 const constant int& K,
888 const constant int& N,
889 const constant int& M,
890 uint3 tid [[threadgroup_position_in_grid]],
891 uint lid [[thread_index_in_threadgroup]],
892 uint simd_gid [[simdgroup_index_in_threadgroup]],
893 uint simd_lid [[thread_index_in_simdgroup]]) {
894 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
895 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
896
897 (void)lid;
898
899 constexpr int WM = 2;
900 constexpr int WN = 2;
901 constexpr int pack_factor = 32 / bits;
902 constexpr int BK_padded = (BK + 16 / sizeof(T));
903 constexpr int BN_padded = (BN + 16 / sizeof(T));
904
905 // Instantiate the appropriate BlockMMA and Loader
906 using mma_t = mlx::steel::
907 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
908 using loader_x_t = mlx::steel::
909 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
910 using loader_w_t = QuantizedBlockLoader<
911 T,
912 BK,
913 BN,
914 BN_padded,
915 0,
916 WM * WN * SIMD_SIZE,
917 group_size,
918 bits>;
919
920 // Set the block
921 const int y_row = tid.y * BM;
922 const int y_col = tid.x * BN;
923 x += y_row * K;
924 w += y_col / pack_factor;
925 scales += y_col / group_size;
926 biases += y_col / group_size;
927 y += y_row * N + y_col;
928
929 // Make the x loader and mma operation
930 const short num_els = min(BM, M - y_row);
931 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
932 loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
933 mma_t mma_op(simd_gid, simd_lid);
934
935 if (num_els < BM) {
936 if ((K % BK) != 0) {
937 const int k_blocks = K / BK;
938 for (int k = 0; k < k_blocks; k++) {
939 threadgroup_barrier(mem_flags::mem_threadgroup);
940 loader_x.load_safe(short2(BK, num_els));
941 loader_w.load_unsafe();
942 threadgroup_barrier(mem_flags::mem_threadgroup);
943 mma_op.mma(Xs, Ws);
944 loader_x.next();
945 loader_w.next();
946 }
947 const short num_k = K - k_blocks * BK;
948 threadgroup_barrier(mem_flags::mem_threadgroup);
949 loader_x.load_safe(short2(num_k, num_els));
950 loader_w.load_safe(short2(BN, num_k));
951 threadgroup_barrier(mem_flags::mem_threadgroup);
952 mma_op.mma(Xs, Ws);
953 } else {
954 for (int k = 0; k < K; k += BK) {
955 threadgroup_barrier(mem_flags::mem_threadgroup);
956 loader_x.load_safe(short2(BK, num_els));
957 loader_w.load_unsafe();
958 threadgroup_barrier(mem_flags::mem_threadgroup);
959 mma_op.mma(Xs, Ws);
960 loader_x.next();
961 loader_w.next();
962 }
963 }
964 } else {
965 if ((K % BK) != 0) {
966 const int k_blocks = K / BK;
967 for (int k = 0; k < k_blocks; k++) {
968 threadgroup_barrier(mem_flags::mem_threadgroup);
969 loader_x.load_unsafe();
970 loader_w.load_unsafe();
971 threadgroup_barrier(mem_flags::mem_threadgroup);
972 mma_op.mma(Xs, Ws);
973 loader_x.next();
974 loader_w.next();
975 }
976 const short num_k = K - k_blocks * BK;
977 threadgroup_barrier(mem_flags::mem_threadgroup);
978 loader_x.load_safe(short2(num_k, BM));
979 loader_w.load_safe(short2(BN, num_k));
980 threadgroup_barrier(mem_flags::mem_threadgroup);
981 mma_op.mma(Xs, Ws);
982 } else {
983 for (int k = 0; k < K; k += BK) {
984 threadgroup_barrier(mem_flags::mem_threadgroup);
985 loader_x.load_unsafe();
986 loader_w.load_unsafe();
987 threadgroup_barrier(mem_flags::mem_threadgroup);
988 mma_op.mma(Xs, Ws);
989 loader_x.next();
990 loader_w.next();
991 }
992 }
993 }
994
995 // Store results to device memory
996 threadgroup_barrier(mem_flags::mem_threadgroup);
997 if (num_els < BM) {
998 mma_op.store_result_safe(y, N, short2(BN, num_els));
999 } else {
1000 mma_op.store_result(y, N);
1001 }
1002}
1003
1004template <typename T>
1005 METAL_FUNC void adjust_matrix_offsets(
1006 const device T*& x,
1007 const device uint32_t*& w,
1008 const device T*& scales,
1009 const device T*& biases,
1010 device T*& y,
1011 int output_stride,
1012 const constant int& x_batch_ndims,
1013 const constant int* x_shape,
1014 const constant size_t* x_strides,
1015 const constant int& w_batch_ndims,
1016 const constant int* w_shape,
1017 const constant size_t* w_strides,
1018 const constant size_t* s_strides,
1019 const constant size_t* b_strides,
1020 uint3 tid [[threadgroup_position_in_grid]]) {
1021 // Set the input/output matrices
1022 uint32_t x_idx = tid.z;
1023 uint32_t w_idx = tid.z;
1024 if (x_batch_ndims == 1) {
1025 x += x_idx * x_strides[0];
1026 } else {
1027 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1028 }
1029 if (w_batch_ndims == 1) {
1030 w += w_idx * w_strides[0];
1031 scales += w_idx * s_strides[0];
1032 biases += w_idx * b_strides[0];
1033 } else {
1034 ulong3 idx = elem_to_loc_broadcast(
1035 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1036 w += idx.x;
1037 scales += idx.y;
1038 biases += idx.z;
1039 }
1040 y += tid.z * output_stride;
1041}
1042
1043template <typename T>
1044 METAL_FUNC void adjust_matrix_offsets(
1045 const device T*& x,
1046 const device uint32_t*& w,
1047 const device T*& scales,
1048 const device T*& biases,
1049 const device uint32_t* lhs_indices,
1050 const device uint32_t* rhs_indices,
1051 device T*& y,
1052 int output_stride,
1053 const constant int& batch_ndims,
1054 const constant int* batch_shape,
1055 const constant size_t* lhs_strides,
1056 const constant size_t* rhs_strides,
1057 const constant int& x_batch_ndims,
1058 const constant int* x_shape,
1059 const constant size_t* x_strides,
1060 const constant int& w_batch_ndims,
1061 const constant int* w_shape,
1062 const constant size_t* w_strides,
1063 const constant size_t* s_strides,
1064 const constant size_t* b_strides,
1065 uint3 tid [[threadgroup_position_in_grid]]) {
1066 // Set the input/output matrices
1067 uint32_t x_idx;
1068 uint32_t w_idx;
1069 if (batch_ndims == 1) {
1070 x_idx = lhs_indices[tid.z * lhs_strides[0]];
1071 w_idx = rhs_indices[tid.z * rhs_strides[0]];
1072 } else {
1073 ulong2 idx = elem_to_loc_broadcast(
1074 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
1075 x_idx = lhs_indices[idx.x];
1076 w_idx = rhs_indices[idx.y];
1077 }
1078 if (x_batch_ndims == 1) {
1079 x += x_idx * x_strides[0];
1080 } else {
1081 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1082 }
1083 if (w_batch_ndims == 1) {
1084 w += w_idx * w_strides[0];
1085 scales += w_idx * s_strides[0];
1086 biases += w_idx * b_strides[0];
1087 } else {
1088 ulong3 idx = elem_to_loc_broadcast(
1089 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1090 w += idx.x;
1091 scales += idx.y;
1092 biases += idx.z;
1093 }
1094 y += tid.z * output_stride;
1095}
1096
1097template <typename T, int group_size, int bits, int D, bool batched>
1098[[kernel]] void qmv_quad(
1099 const device uint32_t* w [[buffer(0)]],
1100 const device T* scales [[buffer(1)]],
1101 const device T* biases [[buffer(2)]],
1102 const device T* x [[buffer(3)]],
1103 device T* y [[buffer(4)]],
1104 const constant int& in_vec_size [[buffer(5)]],
1105 const constant int& out_vec_size [[buffer(6)]],
1106 const constant int& x_batch_ndims [[buffer(7)]],
1107 const constant int* x_shape [[buffer(8)]],
1108 const constant size_t* x_strides [[buffer(9)]],
1109 const constant int& w_batch_ndims [[buffer(10)]],
1110 const constant int* w_shape [[buffer(11)]],
1111 const constant size_t* w_strides [[buffer(12)]],
1112 const constant size_t* s_strides [[buffer(13)]],
1113 const constant size_t* b_strides [[buffer(14)]],
1114 uint3 tid [[threadgroup_position_in_grid]],
1115 uint quad_gid [[quadgroup_index_in_threadgroup]],
1116 uint quad_lid [[thread_index_in_quadgroup]]) {
1117 if (batched) {
1118 adjust_matrix_offsets<T>(
1119 x,
1120 w,
1121 scales,
1122 biases,
1123 y,
1124 out_vec_size,
1125 x_batch_ndims,
1126 x_shape,
1127 x_strides,
1128 w_batch_ndims,
1129 w_shape,
1130 w_strides,
1131 s_strides,
1132 b_strides,
1133 tid);
1134 }
1135 qmv_quad_impl<T, group_size, bits, D>(
1136 w,
1137 scales,
1138 biases,
1139 x,
1140 y,
1141 in_vec_size,
1142 out_vec_size,
1143 tid,
1144 quad_gid,
1145 quad_lid);
1146}
1147
1148template <typename T, int group_size, int bits, bool batched>
1149[[kernel]] void qmv_fast(
1150 const device uint32_t* w [[buffer(0)]],
1151 const device T* scales [[buffer(1)]],
1152 const device T* biases [[buffer(2)]],
1153 const device T* x [[buffer(3)]],
1154 device T* y [[buffer(4)]],
1155 const constant int& in_vec_size [[buffer(5)]],
1156 const constant int& out_vec_size [[buffer(6)]],
1157 const constant int& x_batch_ndims [[buffer(7)]],
1158 const constant int* x_shape [[buffer(8)]],
1159 const constant size_t* x_strides [[buffer(9)]],
1160 const constant int& w_batch_ndims [[buffer(10)]],
1161 const constant int* w_shape [[buffer(11)]],
1162 const constant size_t* w_strides [[buffer(12)]],
1163 const constant size_t* s_strides [[buffer(13)]],
1164 const constant size_t* b_strides [[buffer(14)]],
1165 uint3 tid [[threadgroup_position_in_grid]],
1166 uint simd_gid [[simdgroup_index_in_threadgroup]],
1167 uint simd_lid [[thread_index_in_simdgroup]]) {
1168 if (batched) {
1169 adjust_matrix_offsets<T>(
1170 x,
1171 w,
1172 scales,
1173 biases,
1174 y,
1175 out_vec_size,
1176 x_batch_ndims,
1177 x_shape,
1178 x_strides,
1179 w_batch_ndims,
1180 w_shape,
1181 w_strides,
1182 s_strides,
1183 b_strides,
1184 tid);
1185 }
1186 qmv_fast_impl<T, group_size, bits>(
1187 w,
1188 scales,
1189 biases,
1190 x,
1191 y,
1192 in_vec_size,
1193 out_vec_size,
1194 tid,
1195 simd_gid,
1196 simd_lid);
1197}
1198
1199template <typename T, const int group_size, const int bits, bool batched>
1200[[kernel]] void qmv(
1201 const device uint32_t* w [[buffer(0)]],
1202 const device T* scales [[buffer(1)]],
1203 const device T* biases [[buffer(2)]],
1204 const device T* x [[buffer(3)]],
1205 device T* y [[buffer(4)]],
1206 const constant int& in_vec_size [[buffer(5)]],
1207 const constant int& out_vec_size [[buffer(6)]],
1208 const constant int& x_batch_ndims [[buffer(7)]],
1209 const constant int* x_shape [[buffer(8)]],
1210 const constant size_t* x_strides [[buffer(9)]],
1211 const constant int& w_batch_ndims [[buffer(10)]],
1212 const constant int* w_shape [[buffer(11)]],
1213 const constant size_t* w_strides [[buffer(12)]],
1214 const constant size_t* s_strides [[buffer(13)]],
1215 const constant size_t* b_strides [[buffer(14)]],
1216 uint3 tid [[threadgroup_position_in_grid]],
1217 uint simd_gid [[simdgroup_index_in_threadgroup]],
1218 uint simd_lid [[thread_index_in_simdgroup]]) {
1219 if (batched) {
1220 adjust_matrix_offsets<T>(
1221 x,
1222 w,
1223 scales,
1224 biases,
1225 y,
1226 out_vec_size,
1227 x_batch_ndims,
1228 x_shape,
1229 x_strides,
1230 w_batch_ndims,
1231 w_shape,
1232 w_strides,
1233 s_strides,
1234 b_strides,
1235 tid);
1236 }
1237 qmv_impl<T, group_size, bits>(
1238 w,
1239 scales,
1240 biases,
1241 x,
1242 y,
1243 in_vec_size,
1244 out_vec_size,
1245 tid,
1246 simd_gid,
1247 simd_lid);
1248}
1249
1250template <typename T, const int group_size, const int bits, bool batched>
1251[[kernel]] void qvm(
1252 const device uint32_t* w [[buffer(0)]],
1253 const device T* scales [[buffer(1)]],
1254 const device T* biases [[buffer(2)]],
1255 const device T* x [[buffer(3)]],
1256 device T* y [[buffer(4)]],
1257 const constant int& in_vec_size [[buffer(5)]],
1258 const constant int& out_vec_size [[buffer(6)]],
1259 const constant int& x_batch_ndims [[buffer(7)]],
1260 const constant int* x_shape [[buffer(8)]],
1261 const constant size_t* x_strides [[buffer(9)]],
1262 const constant int& w_batch_ndims [[buffer(10)]],
1263 const constant int* w_shape [[buffer(11)]],
1264 const constant size_t* w_strides [[buffer(12)]],
1265 const constant size_t* s_strides [[buffer(13)]],
1266 const constant size_t* b_strides [[buffer(14)]],
1267 uint3 tid [[threadgroup_position_in_grid]],
1268 uint simd_gid [[simdgroup_index_in_threadgroup]],
1269 uint simd_lid [[thread_index_in_simdgroup]]) {
1270 if (batched) {
1271 adjust_matrix_offsets<T>(
1272 x,
1273 w,
1274 scales,
1275 biases,
1276 y,
1277 out_vec_size,
1278 x_batch_ndims,
1279 x_shape,
1280 x_strides,
1281 w_batch_ndims,
1282 w_shape,
1283 w_strides,
1284 s_strides,
1285 b_strides,
1286 tid);
1287 }
1288 qvm_impl<T, group_size, bits>(
1289 w,
1290 scales,
1291 biases,
1292 x,
1293 y,
1294 in_vec_size,
1295 out_vec_size,
1296 tid,
1297 simd_gid,
1298 simd_lid);
1299}
1300
1301template <typename T, const int group_size, const int bits, int split_k = 32>
1302[[kernel]] void qvm_split_k(
1303 const device uint32_t* w [[buffer(0)]],
1304 const device T* scales [[buffer(1)]],
1305 const device T* biases [[buffer(2)]],
1306 const device T* x [[buffer(3)]],
1307 device T* y [[buffer(4)]],
1308 const constant int& in_vec_size [[buffer(5)]],
1309 const constant int& out_vec_size [[buffer(6)]],
1310 const constant int& x_batch_ndims [[buffer(7)]],
1311 const constant int* x_shape [[buffer(8)]],
1312 const constant size_t* x_strides [[buffer(9)]],
1313 const constant int& w_batch_ndims [[buffer(10)]],
1314 const constant int* w_shape [[buffer(11)]],
1315 const constant size_t* w_strides [[buffer(12)]],
1316 const constant size_t* s_strides [[buffer(13)]],
1317 const constant size_t* b_strides [[buffer(14)]],
1318 const constant int& final_block_size [[buffer(15)]],
1319 uint3 tid [[threadgroup_position_in_grid]],
1320 uint simd_gid [[simdgroup_index_in_threadgroup]],
1321 uint simd_lid [[thread_index_in_simdgroup]]) {
1322 adjust_matrix_offsets<T>(
1323 x,
1324 w,
1325 scales,
1326 biases,
1327 y,
1328 out_vec_size,
1329 x_batch_ndims,
1330 x_shape,
1331 x_strides,
1332 w_batch_ndims,
1333 w_shape,
1334 w_strides,
1335 s_strides,
1336 b_strides,
1337 tid);
1338
1339 // When (in_vec_size % split_k != 0) the final block needs to be smaller
1340 int in_vec_size_adj =
1341 tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
1342
1343 qvm_impl<T, group_size, bits>(
1344 w,
1345 scales,
1346 biases,
1347 x,
1348 y,
1349 in_vec_size_adj,
1350 out_vec_size,
1351 tid,
1352 simd_gid,
1353 simd_lid);
1354}
1355
1356template <
1357 typename T,
1358 const int group_size,
1359 const int bits,
1360 const bool aligned_N,
1361 const bool batched,
1362 const int BM = 32,
1363 const int BK = 32,
1364 const int BN = 32>
1365[[kernel]] void qmm_t(
1366 const device uint32_t* w [[buffer(0)]],
1367 const device T* scales [[buffer(1)]],
1368 const device T* biases [[buffer(2)]],
1369 const device T* x [[buffer(3)]],
1370 device T* y [[buffer(4)]],
1371 const constant int& K [[buffer(5)]],
1372 const constant int& N [[buffer(6)]],
1373 const constant int& M [[buffer(7)]],
1374 const constant int& x_batch_ndims [[buffer(8)]],
1375 const constant int* x_shape [[buffer(9)]],
1376 const constant size_t* x_strides [[buffer(10)]],
1377 const constant int& w_batch_ndims [[buffer(11)]],
1378 const constant int* w_shape [[buffer(12)]],
1379 const constant size_t* w_strides [[buffer(13)]],
1380 const constant size_t* s_strides [[buffer(14)]],
1381 const constant size_t* b_strides [[buffer(15)]],
1382 uint3 tid [[threadgroup_position_in_grid]],
1383 uint lid [[thread_index_in_threadgroup]],
1384 uint simd_gid [[simdgroup_index_in_threadgroup]],
1385 uint simd_lid [[thread_index_in_simdgroup]]) {
1386 (void)lid;
1387
1388 constexpr int BK_padded = (BK + 16 / sizeof(T));
1389
1390 threadgroup T Xs[BM * BK_padded];
1391 threadgroup T Ws[BN * BK_padded];
1392
1393 if (batched) {
1394 adjust_matrix_offsets<T>(
1395 x,
1396 w,
1397 scales,
1398 biases,
1399 y,
1400 M * N,
1401 x_batch_ndims,
1402 x_shape,
1403 x_strides,
1404 w_batch_ndims,
1405 w_shape,
1406 w_strides,
1407 s_strides,
1408 b_strides,
1409 tid);
1410 }
1411 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1412 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1413}
1414
1415template <
1416 typename T,
1417 const int group_size,
1418 const int bits,
1419 const bool batched,
1420 const int BM = 32,
1421 const int BK = 32,
1422 const int BN = 32>
1423[[kernel]] void qmm_n(
1424 const device uint32_t* w [[buffer(0)]],
1425 const device T* scales [[buffer(1)]],
1426 const device T* biases [[buffer(2)]],
1427 const device T* x [[buffer(3)]],
1428 device T* y [[buffer(4)]],
1429 const constant int& K [[buffer(5)]],
1430 const constant int& N [[buffer(6)]],
1431 const constant int& M [[buffer(7)]],
1432 const constant int& x_batch_ndims [[buffer(8)]],
1433 const constant int* x_shape [[buffer(9)]],
1434 const constant size_t* x_strides [[buffer(10)]],
1435 const constant int& w_batch_ndims [[buffer(11)]],
1436 const constant int* w_shape [[buffer(12)]],
1437 const constant size_t* w_strides [[buffer(13)]],
1438 const constant size_t* s_strides [[buffer(14)]],
1439 const constant size_t* b_strides [[buffer(15)]],
1440 uint3 tid [[threadgroup_position_in_grid]],
1441 uint lid [[thread_index_in_threadgroup]],
1442 uint simd_gid [[simdgroup_index_in_threadgroup]],
1443 uint simd_lid [[thread_index_in_simdgroup]]) {
1444 (void)lid;
1445
1446 constexpr int BK_padded = (BK + 16 / sizeof(T));
1447 constexpr int BN_padded = (BN + 16 / sizeof(T));
1448
1449 threadgroup T Xs[BM * BK_padded];
1450 threadgroup T Ws[BK * BN_padded];
1451
1452 if (batched) {
1453 adjust_matrix_offsets<T>(
1454 x,
1455 w,
1456 scales,
1457 biases,
1458 y,
1459 M * N,
1460 x_batch_ndims,
1461 x_shape,
1462 x_strides,
1463 w_batch_ndims,
1464 w_shape,
1465 w_strides,
1466 s_strides,
1467 b_strides,
1468 tid);
1469 }
1470
1471 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1472 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1473}
1474
1475template <typename T, int group_size, int bits>
1476[[kernel]] void bs_qmv_fast(
1477 const device uint32_t* w [[buffer(0)]],
1478 const device T* scales [[buffer(1)]],
1479 const device T* biases [[buffer(2)]],
1480 const device T* x [[buffer(3)]],
1481 device T* y [[buffer(4)]],
1482 const constant int& in_vec_size [[buffer(5)]],
1483 const constant int& out_vec_size [[buffer(6)]],
1484 const constant int& x_batch_ndims [[buffer(7)]],
1485 const constant int* x_shape [[buffer(8)]],
1486 const constant size_t* x_strides [[buffer(9)]],
1487 const constant int& w_batch_ndims [[buffer(10)]],
1488 const constant int* w_shape [[buffer(11)]],
1489 const constant size_t* w_strides [[buffer(12)]],
1490 const constant size_t* s_strides [[buffer(13)]],
1491 const constant size_t* b_strides [[buffer(14)]],
1492 const constant int& batch_ndims [[buffer(15)]],
1493 const constant int* batch_shape [[buffer(16)]],
1494 const device uint32_t* lhs_indices [[buffer(17)]],
1495 const device uint32_t* rhs_indices [[buffer(18)]],
1496 const constant size_t* lhs_strides [[buffer(19)]],
1497 const constant size_t* rhs_strides [[buffer(20)]],
1498 uint3 tid [[threadgroup_position_in_grid]],
1499 uint simd_gid [[simdgroup_index_in_threadgroup]],
1500 uint simd_lid [[thread_index_in_simdgroup]]) {
1501 adjust_matrix_offsets<T>(
1502 x,
1503 w,
1504 scales,
1505 biases,
1506 lhs_indices,
1507 rhs_indices,
1508 y,
1509 out_vec_size,
1510 batch_ndims,
1511 batch_shape,
1512 lhs_strides,
1513 rhs_strides,
1514 x_batch_ndims,
1515 x_shape,
1516 x_strides,
1517 w_batch_ndims,
1518 w_shape,
1519 w_strides,
1520 s_strides,
1521 b_strides,
1522 tid);
1523 qmv_fast_impl<T, group_size, bits>(
1524 w,
1525 scales,
1526 biases,
1527 x,
1528 y,
1529 in_vec_size,
1530 out_vec_size,
1531 tid,
1532 simd_gid,
1533 simd_lid);
1534}
1535
1536template <typename T, int group_size, int bits>
1537[[kernel]] void bs_qmv(
1538 const device uint32_t* w [[buffer(0)]],
1539 const device T* scales [[buffer(1)]],
1540 const device T* biases [[buffer(2)]],
1541 const device T* x [[buffer(3)]],
1542 device T* y [[buffer(4)]],
1543 const constant int& in_vec_size [[buffer(5)]],
1544 const constant int& out_vec_size [[buffer(6)]],
1545 const constant int& x_batch_ndims [[buffer(7)]],
1546 const constant int* x_shape [[buffer(8)]],
1547 const constant size_t* x_strides [[buffer(9)]],
1548 const constant int& w_batch_ndims [[buffer(10)]],
1549 const constant int* w_shape [[buffer(11)]],
1550 const constant size_t* w_strides [[buffer(12)]],
1551 const constant size_t* s_strides [[buffer(13)]],
1552 const constant size_t* b_strides [[buffer(14)]],
1553 const constant int& batch_ndims [[buffer(15)]],
1554 const constant int* batch_shape [[buffer(16)]],
1555 const device uint32_t* lhs_indices [[buffer(17)]],
1556 const device uint32_t* rhs_indices [[buffer(18)]],
1557 const constant size_t* lhs_strides [[buffer(19)]],
1558 const constant size_t* rhs_strides [[buffer(20)]],
1559 uint3 tid [[threadgroup_position_in_grid]],
1560 uint simd_gid [[simdgroup_index_in_threadgroup]],
1561 uint simd_lid [[thread_index_in_simdgroup]]) {
1562 adjust_matrix_offsets<T>(
1563 x,
1564 w,
1565 scales,
1566 biases,
1567 lhs_indices,
1568 rhs_indices,
1569 y,
1570 out_vec_size,
1571 batch_ndims,
1572 batch_shape,
1573 lhs_strides,
1574 rhs_strides,
1575 x_batch_ndims,
1576 x_shape,
1577 x_strides,
1578 w_batch_ndims,
1579 w_shape,
1580 w_strides,
1581 s_strides,
1582 b_strides,
1583 tid);
1584 qmv_impl<T, group_size, bits>(
1585 w,
1586 scales,
1587 biases,
1588 x,
1589 y,
1590 in_vec_size,
1591 out_vec_size,
1592 tid,
1593 simd_gid,
1594 simd_lid);
1595}
1596
1597template <typename T, int group_size, int bits>
1598[[kernel]] void bs_qvm(
1599 const device uint32_t* w [[buffer(0)]],
1600 const device T* scales [[buffer(1)]],
1601 const device T* biases [[buffer(2)]],
1602 const device T* x [[buffer(3)]],
1603 device T* y [[buffer(4)]],
1604 const constant int& in_vec_size [[buffer(5)]],
1605 const constant int& out_vec_size [[buffer(6)]],
1606 const constant int& x_batch_ndims [[buffer(7)]],
1607 const constant int* x_shape [[buffer(8)]],
1608 const constant size_t* x_strides [[buffer(9)]],
1609 const constant int& w_batch_ndims [[buffer(10)]],
1610 const constant int* w_shape [[buffer(11)]],
1611 const constant size_t* w_strides [[buffer(12)]],
1612 const constant size_t* s_strides [[buffer(13)]],
1613 const constant size_t* b_strides [[buffer(14)]],
1614 const constant int& batch_ndims [[buffer(15)]],
1615 const constant int* batch_shape [[buffer(16)]],
1616 const device uint32_t* lhs_indices [[buffer(17)]],
1617 const device uint32_t* rhs_indices [[buffer(18)]],
1618 const constant size_t* lhs_strides [[buffer(19)]],
1619 const constant size_t* rhs_strides [[buffer(20)]],
1620 uint3 tid [[threadgroup_position_in_grid]],
1621 uint simd_gid [[simdgroup_index_in_threadgroup]],
1622 uint simd_lid [[thread_index_in_simdgroup]]) {
1623 adjust_matrix_offsets<T>(
1624 x,
1625 w,
1626 scales,
1627 biases,
1628 lhs_indices,
1629 rhs_indices,
1630 y,
1631 out_vec_size,
1632 batch_ndims,
1633 batch_shape,
1634 lhs_strides,
1635 rhs_strides,
1636 x_batch_ndims,
1637 x_shape,
1638 x_strides,
1639 w_batch_ndims,
1640 w_shape,
1641 w_strides,
1642 s_strides,
1643 b_strides,
1644 tid);
1645 qvm_impl<T, group_size, bits>(
1646 w,
1647 scales,
1648 biases,
1649 x,
1650 y,
1651 in_vec_size,
1652 out_vec_size,
1653 tid,
1654 simd_gid,
1655 simd_lid);
1656}
1657
1658template <
1659 typename T,
1660 const int group_size,
1661 const int bits,
1662 const bool aligned_N,
1663 const int BM = 32,
1664 const int BK = 32,
1665 const int BN = 32>
1666[[kernel]] void bs_qmm_t(
1667 const device uint32_t* w [[buffer(0)]],
1668 const device T* scales [[buffer(1)]],
1669 const device T* biases [[buffer(2)]],
1670 const device T* x [[buffer(3)]],
1671 device T* y [[buffer(4)]],
1672 const constant int& K [[buffer(5)]],
1673 const constant int& N [[buffer(6)]],
1674 const constant int& M [[buffer(7)]],
1675 const constant int& x_batch_ndims [[buffer(8)]],
1676 const constant int* x_shape [[buffer(9)]],
1677 const constant size_t* x_strides [[buffer(10)]],
1678 const constant int& w_batch_ndims [[buffer(11)]],
1679 const constant int* w_shape [[buffer(12)]],
1680 const constant size_t* w_strides [[buffer(13)]],
1681 const constant size_t* s_strides [[buffer(14)]],
1682 const constant size_t* b_strides [[buffer(15)]],
1683 const constant int& batch_ndims [[buffer(16)]],
1684 const constant int* batch_shape [[buffer(17)]],
1685 const device uint32_t* lhs_indices [[buffer(18)]],
1686 const device uint32_t* rhs_indices [[buffer(19)]],
1687 const constant size_t* lhs_strides [[buffer(20)]],
1688 const constant size_t* rhs_strides [[buffer(21)]],
1689 uint3 tid [[threadgroup_position_in_grid]],
1690 uint lid [[thread_index_in_threadgroup]],
1691 uint simd_gid [[simdgroup_index_in_threadgroup]],
1692 uint simd_lid [[thread_index_in_simdgroup]]) {
1693 (void)lid;
1694
1695 constexpr int BK_padded = (BK + 16 / sizeof(T));
1696
1697 threadgroup T Xs[BM * BK_padded];
1698 threadgroup T Ws[BN * BK_padded];
1699
1700 adjust_matrix_offsets<T>(
1701 x,
1702 w,
1703 scales,
1704 biases,
1705 lhs_indices,
1706 rhs_indices,
1707 y,
1708 M * N,
1709 batch_ndims,
1710 batch_shape,
1711 lhs_strides,
1712 rhs_strides,
1713 x_batch_ndims,
1714 x_shape,
1715 x_strides,
1716 w_batch_ndims,
1717 w_shape,
1718 w_strides,
1719 s_strides,
1720 b_strides,
1721 tid);
1722 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1723 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1724}
1725
1726template <
1727 typename T,
1728 const int group_size,
1729 const int bits,
1730 const int BM = 32,
1731 const int BK = 32,
1732 const int BN = 32>
1733[[kernel]] void bs_qmm_n(
1734 const device uint32_t* w [[buffer(0)]],
1735 const device T* scales [[buffer(1)]],
1736 const device T* biases [[buffer(2)]],
1737 const device T* x [[buffer(3)]],
1738 device T* y [[buffer(4)]],
1739 const constant int& K [[buffer(5)]],
1740 const constant int& N [[buffer(6)]],
1741 const constant int& M [[buffer(7)]],
1742 const constant int& x_batch_ndims [[buffer(8)]],
1743 const constant int* x_shape [[buffer(9)]],
1744 const constant size_t* x_strides [[buffer(10)]],
1745 const constant int& w_batch_ndims [[buffer(11)]],
1746 const constant int* w_shape [[buffer(12)]],
1747 const constant size_t* w_strides [[buffer(13)]],
1748 const constant size_t* s_strides [[buffer(14)]],
1749 const constant size_t* b_strides [[buffer(15)]],
1750 const constant int& batch_ndims [[buffer(16)]],
1751 const constant int* batch_shape [[buffer(17)]],
1752 const device uint32_t* lhs_indices [[buffer(18)]],
1753 const device uint32_t* rhs_indices [[buffer(19)]],
1754 const constant size_t* lhs_strides [[buffer(20)]],
1755 const constant size_t* rhs_strides [[buffer(21)]],
1756 uint3 tid [[threadgroup_position_in_grid]],
1757 uint lid [[thread_index_in_threadgroup]],
1758 uint simd_gid [[simdgroup_index_in_threadgroup]],
1759 uint simd_lid [[thread_index_in_simdgroup]]) {
1760 (void)lid;
1761
1762 constexpr int BK_padded = (BK + 16 / sizeof(T));
1763 constexpr int BN_padded = (BN + 16 / sizeof(T));
1764
1765 threadgroup T Xs[BM * BK_padded];
1766 threadgroup T Ws[BK * BN_padded];
1767
1768 adjust_matrix_offsets<T>(
1769 x,
1770 w,
1771 scales,
1772 biases,
1773 lhs_indices,
1774 rhs_indices,
1775 y,
1776 M * N,
1777 batch_ndims,
1778 batch_shape,
1779 lhs_strides,
1780 rhs_strides,
1781 x_batch_ndims,
1782 x_shape,
1783 x_strides,
1784 w_batch_ndims,
1785 w_shape,
1786 w_strides,
1787 s_strides,
1788 b_strides,
1789 tid);
1790 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1791 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1792}
1793
1794template <typename T, const int group_size, const int bits>
1795[[kernel]] void affine_quantize(
1796 const device T* w [[buffer(0)]],
1797 device uint8_t* out [[buffer(1)]],
1798 device T* scales [[buffer(2)]],
1799 device T* biases [[buffer(3)]],
1800 uint2 index [[thread_position_in_grid]],
1801 uint2 grid_dim [[threads_per_grid]]) {
1802 constexpr T eps = T(1e-7);
1803 constexpr int simd_size = 32;
1804 constexpr int uint8_bits = 8;
1805 constexpr T n_bins = (1 << bits) - 1;
1806 constexpr int packs_per_int = uint8_bits / bits;
1807 constexpr int values_per_reduce = group_size / simd_size;
1808 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
1809 constexpr int writes_per_pack =
1810 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
1811
1812 static_assert(
1813 group_size % simd_size == 0,
1814 "Group size must be divisible by simd size.");
1815
1816 size_t offset = index.x + grid_dim.x * size_t(index.y);
1817 size_t in_index = offset * values_per_reduce;
1818 size_t out_index = offset * writes_per_pack;
1819
1820 T w_thread[values_per_reduce];
1821 T w_min = Limits<T>::max;
1822 T w_max = 0;
1823
1824#pragma clang loop unroll(full)
1825 for (int i = 0; i < values_per_reduce; i++) {
1826 T val = w[in_index + i];
1827 w_thread[i] = val;
1828 w_min = min(w_min, val);
1829 w_max = max(w_max, val);
1830 }
1831
1832 w_min = simd_min(w_min);
1833 w_max = simd_max(w_max);
1834
1835 T scale = max((w_max - w_min) / n_bins, eps);
1836 bool side = abs(w_min) > abs(w_max);
1837 scale = side ? scale : -scale;
1838 T edge = side ? w_min : w_max;
1839 T q0 = round(edge / scale);
1840 bool at_zero = q0 == 0.0f;
1841 scale = at_zero ? scale : edge / q0;
1842 T bias = at_zero ? T(0) : edge;
1843
1844 // Write out the scales and biases
1845 size_t gindex = in_index / group_size;
1846 if (in_index % group_size == 0) {
1847 scales[gindex] = scale;
1848 biases[gindex] = bias;
1849 }
1850
1851 uint8_t output = 0;
1852#pragma clang loop unroll(full)
1853 for (int i = 0; i < values_per_reduce; i++) {
1854 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
1855 if (bits == 8) {
1856 output = val;
1857 } else {
1858 output += val << (bits * (i % packs_per_int));
1859 }
1860
1861 if (packs_per_int < values_per_reduce &&
1862 i % packs_per_int == packs_per_int - 1) {
1863 out[out_index + i / packs_per_int] = output;
1864 output = 0;
1865 } else {
1866#pragma clang loop unroll(full)
1867 for (int j = 0; j < writes_per_reduce - 1; j++) {
1868 uint8_t sval = simd_shuffle_down(val, j + 1);
1869 output += sval << (bits * (values_per_reduce + j + i));
1870 }
1871 }
1872 }
1873 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
1874 out[out_index / writes_per_reduce] = output;
1875 }
1876}
1877
1878template <typename T, const int group_size, const int bits>
1879 [[kernel]] void affine_quantize_scales_biases(
1880 const device T* w [[buffer(0)]],
1881 const device T* scales [[buffer(1)]],
1882 const device T* biases [[buffer(2)]],
1883 device uint8_t* out [[buffer(3)]],
1884 uint2 index [[thread_position_in_grid]],
1885 uint2 grid_dim [[threads_per_grid]]) {
1886 constexpr int uint8_bits = 8;
1887 constexpr int packs_per_int = uint8_bits / bits;
1888 constexpr T n_bins = (1 << bits) - 1;
1889
1890 size_t offset = index.x + grid_dim.x * size_t(index.y);
1891 size_t in_index = offset * packs_per_int;
1892 size_t gindex = in_index / group_size;
1893
1894 T scale = scales[gindex];
1895 T bias = biases[gindex];
1896
1897 uint8_t output = 0;
1898#pragma clang loop unroll(full)
1899 for (int i = 0; i < packs_per_int; i++) {
1900 uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
1901 if (bits == 8) {
1902 output = val;
1903 } else {
1904 output += val << (bits * i);
1905 }
1906 }
1907 out[offset] = output;
1908}
1909
1910template <typename T, const int group_size, const int bits>
1911[[kernel]] void affine_dequantize(
1912 const device uint8_t* w [[buffer(0)]],
1913 const device T* scales [[buffer(1)]],
1914 const device T* biases [[buffer(2)]],
1915 device T* out [[buffer(3)]],
1916 uint2 index [[thread_position_in_grid]],
1917 uint2 grid_dim [[threads_per_grid]]) {
1918 constexpr int uint8_bits = 8;
1919 constexpr int packs_per_int = uint8_bits / bits;
1920
1921 size_t offset = index.x + grid_dim.x * size_t(index.y);
1922 size_t oindex = offset * packs_per_int;
1923 size_t gindex = oindex / group_size;
1924 T scale = scales[gindex];
1925 T bias = biases[gindex];
1926 uint val = w[offset];
1927
1928#pragma clang loop unroll(full)
1929 for (int i = 0; i < packs_per_int; i++) {
1930 uint8_t d;
1931 if (bits == 2) {
1932 d = (val >> (bits * i)) & 0x03;
1933 } else if (bits == 4) {
1934 d = (val >> (bits * i)) & 0x0f;
1935 } else if (bits == 8) {
1936 d = val;
1937 }
1938 out[oindex + i] = scale * d + bias;
1939 }
1940}
Definition quantized.h:1476
void affine_dequantize(const device uint8_t *w, const device T *scales, const device T *biases, device T *out, uint2 index, uint2 grid_dim)
Definition quantized.h:1911
static constant constexpr const int SIMD_SIZE
Definition quantized.h:10
void qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1200
void bs_qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1598
void affine_quantize_scales_biases(const device T *w, const device T *scales, const device T *biases, device uint8_t *out, uint2 index, uint2 grid_dim)
Definition quantized.h:1879
void qmv_fast(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1149
void qmv_quad(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:1098
static constant constexpr const int QUAD_SIZE
Definition quantized.h:11
U load_vector(const device T *x, thread U *x_thread)
Definition quantized.h:14
METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:498
U load_vector_safe(const device T *x, thread U *x_thread, int N)
Definition quantized.h:52
void bs_qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1666
U qdot(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
Definition quantized.h:99
void qvm_split_k(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &final_block_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1302
METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:434
void qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1365
METAL_FUNC void adjust_matrix_offsets(const device T *&x, const device uint32_t *&w, const device T *&scales, const device T *&biases, device T *&y, int output_stride, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid)
Definition quantized.h:1005
void bs_qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1537
METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:376
void qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1251
void qouter(const thread uint8_t *w, U x, U scale, U bias, thread U *result)
Definition quantized.h:187
void dequantize(const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
Definition quantized.h:219
METAL_FUNC void qmm_t_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:758
Definition utils.h:17
Definition quantized.h:262
const int group_stride
Definition quantized.h:282
static constant constexpr const short BCOLS_PACKED
Definition quantized.h:274
const device T * biases
Definition quantized.h:291
short group_step_cnt
Definition quantized.h:281
static constant constexpr const short group_steps
Definition quantized.h:277
const short thread_idx
Definition quantized.h:284
const device T * scales
Definition quantized.h:290
static constant constexpr const short n_reads
Definition quantized.h:275
void next()
Definition quantized.h:354
void load_safe(short2 src_tile_dim) const
Definition quantized.h:327
const int src_ld
Definition quantized.h:279
const short bi
Definition quantized.h:285
void load_unsafe() const
Definition quantized.h:314
static constant constexpr const short pack_factor
Definition quantized.h:273
threadgroup T * dst
Definition quantized.h:288
const int tile_stride
Definition quantized.h:280
const device uint32_t * src
Definition quantized.h:289
const short bj
Definition quantized.h:286
QuantizedBlockLoader(const device uint32_t *src_, const device T *scales_, const device T *biases_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition quantized.h:293
Definition loader.h:25