MLX
quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10MLX_MTL_CONST int SIMD_SIZE = 32;
11MLX_MTL_CONST int QUAD_SIZE = 4;
12
13template <typename T, typename U, int values_per_thread, int bits>
14inline U load_vector(const device T* x, thread U* x_thread) {
15 static_assert(
16 bits == 2 || bits == 4 || bits == 8,
17 "Template undefined for bits not in {2, 4, 8}");
18
19 U sum = 0;
20
21 if (bits == 2) {
22 for (int i = 0; i < values_per_thread; i += 4) {
23 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
24 x_thread[i] = x[i];
25 x_thread[i + 1] = x[i + 1] / 4.0f;
26 x_thread[i + 2] = x[i + 2] / 16.0f;
27 x_thread[i + 3] = x[i + 3] / 64.0f;
28 }
29 }
30
31 else if (bits == 4) {
32 for (int i = 0; i < values_per_thread; i += 4) {
33 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
34 x_thread[i] = x[i];
35 x_thread[i + 1] = x[i + 1] / 16.0f;
36 x_thread[i + 2] = x[i + 2] / 256.0f;
37 x_thread[i + 3] = x[i + 3] / 4096.0f;
38 }
39 }
40
41 else if (bits == 8) {
42 for (int i = 0; i < values_per_thread; i++) {
43 sum += x[i];
44 x_thread[i] = x[i];
45 }
46 }
47
48 return sum;
49}
50
51template <typename T, typename U, int values_per_thread, int bits>
52inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
53 static_assert(
54 bits == 2 || bits == 4 || bits == 8,
55 "Template undefined for bits not in {2, 4, 8}");
56
57 U sum = 0;
58
59 if (bits == 2) {
60 for (int i = 0; i < N; i += 4) {
61 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
62 x_thread[i] = x[i];
63 x_thread[i + 1] = x[i + 1] / 4.0f;
64 x_thread[i + 2] = x[i + 2] / 16.0f;
65 x_thread[i + 3] = x[i + 3] / 64.0f;
66 }
67 for (int i = N; i < values_per_thread; i++) {
68 x_thread[i] = 0;
69 }
70 }
71
72 else if (bits == 4) {
73 for (int i = 0; i < N; i += 4) {
74 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
75 x_thread[i] = x[i];
76 x_thread[i + 1] = x[i + 1] / 16.0f;
77 x_thread[i + 2] = x[i + 2] / 256.0f;
78 x_thread[i + 3] = x[i + 3] / 4096.0f;
79 }
80 for (int i = N; i < values_per_thread; i++) {
81 x_thread[i] = 0;
82 }
83 }
84
85 else if (bits == 8) {
86 for (int i = 0; i < N; i++) {
87 sum += x[i];
88 x_thread[i] = x[i];
89 }
90 for (int i = N; i < values_per_thread; i++) {
91 x_thread[i] = 0;
92 }
93 }
94
95 return sum;
96}
97
98template <typename U, int values_per_thread, int bits>
99inline U qdot(
100 const device uint8_t* w,
101 const thread U* x_thread,
102 U scale,
103 U bias,
104 U sum) {
105 static_assert(
106 bits == 2 || bits == 4 || bits == 8,
107 "Template undefined for bits not in {2, 4, 8}");
108
109 U accum = 0;
110
111 if (bits == 2) {
112 for (int i = 0; i < (values_per_thread / 4); i++) {
113 accum +=
114 (x_thread[4 * i] * (w[i] & 0x03) +
115 x_thread[4 * i + 1] * (w[i] & 0x0c) +
116 x_thread[4 * i + 2] * (w[i] & 0x30) +
117 x_thread[4 * i + 3] * (w[i] & 0xc0));
118 }
119 }
120
121 else if (bits == 4) {
122 const device uint16_t* ws = (const device uint16_t*)w;
123 for (int i = 0; i < (values_per_thread / 4); i++) {
124 accum +=
125 (x_thread[4 * i] * (ws[i] & 0x000f) +
126 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
127 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
128 x_thread[4 * i + 3] * (ws[i] & 0xf000));
129 }
130 }
131
132 else if (bits == 8) {
133 for (int i = 0; i < values_per_thread; i++) {
134 accum += x_thread[i] * w[i];
135 }
136 }
137
138 return scale * accum + sum * bias;
139}
140
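Why load_vector divides the activations by 4/16/64 (or 16/256/4096): qdot masks each packed weight in place without shifting it down, so a value stored in a higher nibble or bit-pair comes out multiplied by a power of two, and the pre-scaled activation cancels that factor in the product. Below is a minimal host-side C++ check of this identity for bits == 4; the concrete activations, weights, scale and bias are made up for illustration and are not from the library.

// Host-side check of the qdot/load_vector identity for bits == 4 (illustrative values).
#include <cstdint>
#include <cstdio>
#include <cmath>

int main() {
  // 8 activations and 8 unsigned 4-bit weights (one quantization group, assumed).
  float x[8] = {0.5f, -1.0f, 2.0f, 0.25f, 1.5f, -0.75f, 3.0f, 0.125f};
  int wq[8] = {1, 15, 7, 0, 9, 3, 12, 5};
  float scale = 0.02f, bias = -0.1f;

  // Reference: dequantize each weight, then take the dot product.
  float ref = 0.0f;
  for (int i = 0; i < 8; i++) ref += x[i] * (scale * wq[i] + bias);

  // Kernel-style: pre-scale activations exactly as load_vector does for bits == 4 ...
  float x_thread[8], sum = 0.0f;
  for (int i = 0; i < 8; i += 4) {
    sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
    x_thread[i] = x[i];
    x_thread[i + 1] = x[i + 1] / 16.0f;
    x_thread[i + 2] = x[i + 2] / 256.0f;
    x_thread[i + 3] = x[i + 3] / 4096.0f;
  }
  // ... pack four nibbles per uint16 word and accumulate with masks only (no shifts),
  // mirroring the bits == 4 branch of qdot.
  uint16_t ws[2];
  for (int i = 0; i < 2; i++)
    ws[i] = wq[4 * i] | (wq[4 * i + 1] << 4) | (wq[4 * i + 2] << 8) | (wq[4 * i + 3] << 12);
  float accum = 0.0f;
  for (int i = 0; i < 2; i++) {
    accum += x_thread[4 * i]     * (ws[i] & 0x000f) +
             x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
             x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
             x_thread[4 * i + 3] * (ws[i] & 0xf000);
  }
  float got = scale * accum + sum * bias;
  printf("ref=%f kernel-style=%f diff=%g\n", ref, got, std::fabs(ref - got));
}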
141template <typename U, int values_per_thread, int bits>
142inline U qdot_safe(
143 const device uint8_t* w,
144 const thread U* x_thread,
145 U scale,
146 U bias,
147 U sum,
148 int N) {
149 static_assert(
150 bits == 2 || bits == 4 || bits == 8,
151 "Template undefined for bits not in {2, 4, 8}");
152
153 U accum = 0;
154
155 if (bits == 2) {
156 for (int i = 0; i < (N / 4); i++) {
157 accum +=
158 (x_thread[4 * i] * (w[i] & 0x03) +
159 x_thread[4 * i + 1] * (w[i] & 0x0c) +
160 x_thread[4 * i + 2] * (w[i] & 0x30) +
161 x_thread[4 * i + 3] * (w[i] & 0xc0));
162 }
163 }
164
165 else if (bits == 4) {
166 const device uint16_t* ws = (const device uint16_t*)w;
167 for (int i = 0; i < (N / 4); i++) {
168 accum +=
169 (x_thread[4 * i] * (ws[i] & 0x000f) +
170 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
171 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
172 x_thread[4 * i + 3] * (ws[i] & 0xf000));
173 }
174 }
175
176 else if (bits == 8) {
177 for (int i = 0; i < N; i++) {
178 accum += x_thread[i] * w[i];
179 }
180 }
181
182 return scale * accum + sum * bias;
183}
184
185template <typename U, int values_per_thread, int bits>
186inline void
187qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
188 static_assert(
189 bits == 2 || bits == 4 || bits == 8,
190 "Template undefined for bits not in {2, 4, 8}");
191
192 if (bits == 2) {
193 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
194 for (int i = 0; i < (values_per_thread / 4); i++) {
195 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
196 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
197 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
198 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
199 }
200 }
201
202 else if (bits == 4) {
203 U s[2] = {scale, scale / 16.0f};
204 for (int i = 0; i < (values_per_thread / 2); i++) {
205 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
206 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
207 }
208 }
209
210 else if (bits == 8) {
211 for (int i = 0; i < values_per_thread; i++) {
212 result[i] += x * (scale * w[i] + bias);
213 }
214 }
215}
216
217template <typename U, int N, int bits>
218inline void
219dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
220 static_assert(
221 bits == 2 || bits == 4 || bits == 8,
222 "Template undefined for bits not in {2, 4, 8}");
223
224 if (bits == 2) {
225 U s[4] = {
226 scale,
227 scale / static_cast<U>(4.0f),
228 scale / static_cast<U>(16.0f),
229 scale / static_cast<U>(64.0f)};
230 for (int i = 0; i < (N / 4); i++) {
231 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
232 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
233 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
234 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
235 }
236 }
237
238 else if (bits == 4) {
239 U s[2] = {scale, scale / static_cast<U>(16.0f)};
240 for (int i = 0; i < (N / 2); i++) {
241 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
242 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
243 }
244 }
245
246 else if (bits == 8) {
247 for (int i = 0; i < N; i++) {
248 w_local[i] = scale * w[i] + bias;
249 }
250 }
251}
252
253template <
254 typename T,
255 short BROWS,
256 short BCOLS,
257 short dst_ld,
258 short reduction_dim,
259 short tgp_size,
260 short group_size,
261 short bits>
262struct QuantizedBlockLoader {
263 static_assert(
264 BCOLS <= group_size,
265 "The group size should be larger than the columns");
266 static_assert(
267 group_size % BCOLS == 0,
268 "The group size should be divisible by the columns");
269 static_assert(
270 bits == 2 || bits == 4 || bits == 8,
271 "Template undefined for bits not in {2, 4, 8}");
272
273 MLX_MTL_CONST short pack_factor = 32 / bits;
274 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
275 MLX_MTL_CONST short n_reads =
276 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
277 MLX_MTL_CONST short group_steps = group_size / BCOLS;
278
279 const int src_ld;
280 const int tile_stride;
281 short group_step_cnt;
282 const int group_stride;
283
284 const short thread_idx;
285 const short bi;
286 const short bj;
287
288 threadgroup T* dst;
289 const device uint32_t* src;
290 const device T* scales;
291 const device T* biases;
292
293 QuantizedBlockLoader(
294 const device uint32_t* src_,
295 const device T* scales_,
296 const device T* biases_,
297 const int src_ld_,
298 threadgroup T* dst_,
299 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
300 ushort simd_lane_id [[thread_index_in_simdgroup]])
301 : src_ld(src_ld_),
302 tile_stride(
303 reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
304 group_step_cnt(0),
305 group_stride(BROWS * src_ld / group_size),
306 thread_idx(simd_group_id * 32 + simd_lane_id),
307 bi(n_reads * thread_idx / BCOLS_PACKED),
308 bj((n_reads * thread_idx) % BCOLS_PACKED),
309 dst(dst_ + bi * dst_ld + bj * pack_factor),
310 src(src_ + bi * src_ld / pack_factor + bj),
311 scales(scales_ + bi * src_ld / group_size),
312 biases(biases_ + bi * src_ld / group_size) {}
313
314 void load_unsafe() const {
315 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
316 return;
317 }
318
319 T scale = *scales;
320 T bias = *biases;
321 for (int i = 0; i < n_reads; i++) {
322 dequantize<T, pack_factor, bits>(
323 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
324 }
325 }
326
327 void load_safe(short2 src_tile_dim) const {
328 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
329 return;
330 }
331
332 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
333 for (int i = 0; i < n_reads * pack_factor; i++) {
334 dst[i] = T(0);
335 }
336 return;
337 }
338
339 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
340 for (int i = 0; i < n_reads * pack_factor; i++) {
341 dst[i] = T(0);
342 }
343 return;
344 }
345
346 T scale = *scales;
347 T bias = *biases;
348 for (int i = 0; i < n_reads; i++) {
349 dequantize<T, pack_factor, bits>(
350 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
351 }
352 }
353
354 void next() {
355 src += tile_stride;
356 if (reduction_dim == 1) {
357 if (group_steps > 1) {
358 group_step_cnt++;
359 if (group_step_cnt == group_steps) {
360 group_step_cnt = 0;
361 scales++;
362 biases++;
363 }
364 } else {
365 scales++;
366 biases++;
367 }
368 } else {
369 scales += group_stride;
370 biases += group_stride;
371 }
372 }
373};
374
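For orientation, here are the loader's compile-time constants worked out for one plausible configuration (bits = 4, a 32x32 weight tile, group_size = 64, and the 128-thread threadgroup that qmm_t/qmm_n below use). The numbers and the tiny program are illustrative only.

// Worked values for QuantizedBlockLoader's constants under one assumed configuration.
#include <cstdio>

int main() {
  constexpr int bits = 4, BCOLS = 32, BROWS = 32, group_size = 64, tgp_size = 128;
  constexpr int pack_factor = 32 / bits;             // 8 values per uint32 word
  constexpr int BCOLS_PACKED = BCOLS / pack_factor;  // 4 packed words per tile row
  constexpr int n_reads =                            // packed words each thread loads
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  constexpr int group_steps = group_size / BCOLS;    // column tiles per scale/bias group
  printf("pack_factor=%d BCOLS_PACKED=%d n_reads=%d group_steps=%d\n",
         pack_factor, BCOLS_PACKED, n_reads, group_steps);
}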
375template <typename T, int group_size, int bits, int D>
376METAL_FUNC void qmv_quad_impl(
377 const device uint32_t* w,
378 const device T* scales,
379 const device T* biases,
380 const device T* x,
381 device T* y,
382 constant int& in_vec_size,
383 const constant int& out_vec_size,
384 uint3 tid [[threadgroup_position_in_grid]],
385 uint quad_gid [[quadgroup_index_in_threadgroup]],
386 uint quad_lid [[thread_index_in_quadgroup]]) {
387 constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
388 constexpr int pack_factor = 32 / bits;
389 constexpr int values_per_thread = D / QUAD_SIZE;
390 constexpr int packs_per_thread = values_per_thread / pack_factor;
391 constexpr int scale_step_per_thread = group_size / values_per_thread;
392 constexpr int results_per_quadgroup = 8;
393
394 typedef float U;
395
396 thread U x_thread[values_per_thread];
397 thread U result[results_per_quadgroup] = {0};
398
399 // Adjust positions
400 const int in_vec_size_w = in_vec_size / pack_factor;
401 const int in_vec_size_g = in_vec_size / group_size;
402 const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
403
404 w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
405 scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
406 biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
407 x += tid.y * in_vec_size + quad_lid * values_per_thread;
408 y += tid.y * out_vec_size + out_row;
409
410 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
411
412 for (int row = 0; row < results_per_quadgroup; row++) {
413 const device uint8_t* wl =
414 (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
415 const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
416 const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
417
418 U s = sl[0];
419 U b = bl[0];
420 if (row * quads_per_simd + out_row < out_vec_size) {
421 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
422 }
423 }
424
425 for (int row = 0; row < results_per_quadgroup; row++) {
426 result[row] = quad_sum(result[row]);
427 if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
428 y[row * quads_per_simd] = static_cast<T>(result[row]);
429 }
430 }
431}
432
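qmv_quad_impl is meant for short input vectors: each quadgroup of QUAD_SIZE = 4 threads covers the whole D-element input, and each quadgroup starts at its own output row and accumulates results_per_quadgroup = 8 rows spaced quads_per_simd apart. A small host-side C++ sketch of that row mapping, with tid.x and the quad index as plain ints and the values chosen purely for illustration:

// Output-row mapping used by qmv_quad_impl, written out on the host for one example
// dispatch (SIMD_SIZE = 32, QUAD_SIZE = 4, results_per_quadgroup = 8).
#include <cstdio>

int main() {
  const int SIMD_SIZE = 32, QUAD_SIZE = 4, results_per_quadgroup = 8;
  const int quads_per_simd = SIMD_SIZE / QUAD_SIZE;  // 8 quadgroups per simdgroup
  const int tid_x = 3;                               // threadgroup index along x (example)
  for (int quad_gid = 0; quad_gid < quads_per_simd; quad_gid++) {
    int out_row = tid_x * quads_per_simd * results_per_quadgroup + quad_gid;
    // Each quadgroup then accumulates rows out_row, out_row + 8, ..., out_row + 56.
    printf("quad %d -> rows %d..%d (stride %d)\n",
           quad_gid, out_row,
           out_row + (results_per_quadgroup - 1) * quads_per_simd, quads_per_simd);
  }
}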
433template <typename T, int group_size, int bits>
434METAL_FUNC void qmv_fast_impl(
435 const device uint32_t* w,
436 const device T* scales,
437 const device T* biases,
438 const device T* x,
439 device T* y,
440 const constant int& in_vec_size,
441 const constant int& out_vec_size,
442 uint3 tid [[threadgroup_position_in_grid]],
443 uint simd_gid [[simdgroup_index_in_threadgroup]],
444 uint simd_lid [[thread_index_in_simdgroup]]) {
445 constexpr int packs_per_thread = bits > 2 ? 2 : 1;
446 constexpr int num_simdgroups = 2;
447 constexpr int results_per_simdgroup = 4;
448 constexpr int pack_factor = 32 / bits;
449 constexpr int values_per_thread = pack_factor * packs_per_thread;
450 constexpr int block_size = values_per_thread * SIMD_SIZE;
451 constexpr int scale_step_per_thread = group_size / values_per_thread;
452
453 typedef float U;
454
455 thread U x_thread[values_per_thread];
456 thread U result[results_per_simdgroup] = {0};
457
458 // Adjust positions
459 const int in_vec_size_w = in_vec_size / pack_factor;
460 const int in_vec_size_g = in_vec_size / group_size;
461 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
462 simd_gid * results_per_simdgroup;
463 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
464 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
465 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
466 x += tid.y * in_vec_size + simd_lid * values_per_thread;
467 y += tid.y * out_vec_size + out_row;
468
469 for (int k = 0; k < in_vec_size; k += block_size) {
470 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
471
472 for (int row = 0; row < results_per_simdgroup; row++) {
473 const device uint8_t* wl =
474 (const device uint8_t*)(w + row * in_vec_size_w);
475 const device T* sl = scales + row * in_vec_size_g;
476 const device T* bl = biases + row * in_vec_size_g;
477
478 U s = sl[0];
479 U b = bl[0];
480 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
481 }
482
483 w += block_size / pack_factor;
484 scales += block_size / group_size;
485 biases += block_size / group_size;
486 x += block_size;
487 }
488
489 for (int row = 0; row < results_per_simdgroup; row++) {
490 result[row] = simd_sum(result[row]);
491 if (simd_lid == 0) {
492 y[row] = static_cast<T>(result[row]);
493 }
494 }
495}
496
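With num_simdgroups = 2 and results_per_simdgroup = 4, each threadgroup of qmv_fast_impl produces eight rows of y, and every simdgroup lane walks the input vector in strides of block_size before the final simd_sum reduction. The arithmetic being distributed is the plain per-group affine matvec; a scalar C++ reference is sketched below using an unpacked byte-per-value layout for clarity (a simplification of the real uint32 packing, and the helper name qmv_reference is not part of MLX).

// Scalar reference for the affine quantized matvec computed by qmv_fast_impl:
// y[m] = sum_k (scale[m][k/group_size] * q[m][k] + bias[m][k/group_size]) * x[k].
#include <cstdint>
#include <vector>

std::vector<float> qmv_reference(
    const std::vector<uint8_t>& q,     // M x K quantized weights, row-major
    const std::vector<float>& scales,  // M x (K / group_size)
    const std::vector<float>& biases,  // M x (K / group_size)
    const std::vector<float>& x,       // K
    int M, int K, int group_size) {
  std::vector<float> y(M, 0.0f);
  int groups = K / group_size;
  for (int m = 0; m < M; m++) {
    float acc = 0.0f;
    for (int g = 0; g < groups; g++) {
      float s = scales[m * groups + g];
      float b = biases[m * groups + g];
      for (int k = g * group_size; k < (g + 1) * group_size; k++) {
        acc += (s * q[m * K + k] + b) * x[k];
      }
    }
    y[m] = acc;
  }
  return y;
}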
497template <typename T, int group_size, int bits>
498METAL_FUNC void qmv_impl(
499 const device uint32_t* w,
500 const device T* scales,
501 const device T* biases,
502 const device T* x,
503 device T* y,
504 const constant int& in_vec_size,
505 const constant int& out_vec_size,
506 uint3 tid [[threadgroup_position_in_grid]],
507 uint simd_gid [[simdgroup_index_in_threadgroup]],
508 uint simd_lid [[thread_index_in_simdgroup]]) {
509 constexpr int num_simdgroups = 2;
510 constexpr int results_per_simdgroup = 4;
511 constexpr int packs_per_thread = 1;
512 constexpr int pack_factor = 32 / bits;
513 constexpr int values_per_thread = pack_factor * packs_per_thread;
514 constexpr int block_size = values_per_thread * SIMD_SIZE;
515 constexpr int scale_step_per_thread = group_size / values_per_thread;
516
517 typedef float U;
518
519 thread U x_thread[values_per_thread];
520 thread U result[results_per_simdgroup] = {0};
521
522 // Adjust positions
523 const int in_vec_size_w = in_vec_size / pack_factor;
524 const int in_vec_size_g = in_vec_size / group_size;
525 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
526 simd_gid * results_per_simdgroup;
527 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
528
529 if (out_row >= out_vec_size) {
530 return;
531 }
532
533 // In this case we need to properly guard all our reads because there isn't
534 // even 1 tile in the matrix
535 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
536 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
537 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
538 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
539 x += tid.y * in_vec_size + simd_lid * values_per_thread;
540 y += tid.y * out_vec_size + out_row;
541
542 int k = 0;
543 for (; k < in_vec_size - block_size; k += block_size) {
544 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
545
546 for (int row = 0; out_row + row < out_vec_size; row++) {
547 const device uint8_t* wl =
548 (const device uint8_t*)(w + row * in_vec_size_w);
549 const device T* sl = scales + row * in_vec_size_g;
550 const device T* bl = biases + row * in_vec_size_g;
551
552 U s = sl[0];
553 U b = bl[0];
554 result[row] +=
555 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
556 }
557
558 w += block_size / pack_factor;
559 scales += block_size / group_size;
560 biases += block_size / group_size;
561 x += block_size;
562 }
563 const int remaining = clamp(
564 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
565 0,
566 values_per_thread);
567 U sum =
568 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
569
570 for (int row = 0; out_row + row < out_vec_size; row++) {
571 const device uint8_t* wl =
572 (const device uint8_t*)(w + row * in_vec_size_w);
573 const device T* sl = scales + row * in_vec_size_g;
574 const device T* bl = biases + row * in_vec_size_g;
575
576 U s = sl[0];
577 U b = bl[0];
578 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
579 }
580
581 for (int row = 0; out_row + row < out_vec_size; row++) {
582 result[row] = simd_sum(result[row]);
583 if (simd_lid == 0) {
584 y[row] = static_cast<T>(result[row]);
585 }
586 }
587 }
588
589 // In this case the last tile is moved back to redo some output values
590 else {
591 w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
592 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
593 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
594 x += tid.y * in_vec_size + simd_lid * values_per_thread;
595 y += tid.y * out_vec_size + used_out_row;
596
597 int k = 0;
598 for (; k < in_vec_size - block_size; k += block_size) {
599 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
600
601 for (int row = 0; row < results_per_simdgroup; row++) {
602 const device uint8_t* wl =
603 (const device uint8_t*)(w + row * in_vec_size_w);
604 const device T* sl = scales + row * in_vec_size_g;
605 const device T* bl = biases + row * in_vec_size_g;
606
607 U s = sl[0];
608 U b = bl[0];
609 result[row] +=
610 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
611 }
612
613 w += block_size / pack_factor;
614 scales += block_size / group_size;
615 biases += block_size / group_size;
616 x += block_size;
617 }
618 const int remaining = clamp(
619 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
620 0,
621 values_per_thread);
622 U sum =
623 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
624
625 for (int row = 0; row < results_per_simdgroup; row++) {
626 const device uint8_t* wl =
627 (const device uint8_t*)(w + row * in_vec_size_w);
628 const device T* sl = scales + row * in_vec_size_g;
629 const device T* bl = biases + row * in_vec_size_g;
630
631 U s = sl[0];
632 U b = bl[0];
633 result[row] += qdot_safe<U, values_per_thread, bits>(
634 wl, x_thread, s, b, sum, remaining);
635 }
636
637 for (int row = 0; row < results_per_simdgroup; row++) {
638 result[row] = simd_sum(result[row]);
639 if (simd_lid == 0) {
640 y[row] = static_cast<T>(result[row]);
641 }
642 }
643 }
644}
645
646template <typename T, const int group_size, const int bits>
647METAL_FUNC void qvm_impl(
648 const device uint32_t* w,
649 const device T* scales,
650 const device T* biases,
651 const device T* x,
652 device T* y,
653 const constant int& in_vec_size,
654 const constant int& out_vec_size,
655 uint3 tid [[threadgroup_position_in_grid]],
656 uint simd_gid [[simdgroup_index_in_threadgroup]],
657 uint simd_lid [[thread_index_in_simdgroup]]) {
658 constexpr int num_simdgroups = 2;
659 constexpr int pack_factor = 32 / bits;
660 constexpr int tn = 32 / pack_factor;
661 constexpr int blocksize = SIMD_SIZE;
662
663 typedef float U;
664 typedef struct {
665 uint32_t wi[tn];
666 } vec_w;
667
668 thread vec_w w_local;
669 thread U result[tn * pack_factor] = {0};
670 thread U scale = 1;
671 thread U bias = 0;
672 thread U x_local = 0;
673
674 // Adjust positions
675 const int out_vec_size_w = out_vec_size / pack_factor;
676 const int out_vec_size_g = out_vec_size / group_size;
677 int out_col =
678 tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
679 w += out_col / pack_factor + simd_lid * out_vec_size_w;
680 scales += out_col / group_size + simd_lid * out_vec_size_g;
681 biases += out_col / group_size + simd_lid * out_vec_size_g;
682 x += tid.y * in_vec_size + simd_lid;
683 y += tid.y * out_vec_size + out_col;
684
685 if (out_col >= out_vec_size) {
686 return;
687 }
688
689 // Loop over in_vec in blocks of blocksize
690 int remaining = in_vec_size % blocksize;
691 if (remaining == 0) {
692 for (int i = 0; i < in_vec_size; i += blocksize) {
693 x_local = *x;
694 scale = *scales;
695 bias = *biases;
696 w_local = *((device vec_w*)w);
697
698 qouter<U, tn * pack_factor, bits>(
699 (thread uint8_t*)&w_local, x_local, scale, bias, result);
700
701 x += blocksize;
702 scales += blocksize * out_vec_size_g;
703 biases += blocksize * out_vec_size_g;
704 w += blocksize * out_vec_size_w;
705 }
706 } else {
707 for (int i = blocksize; i < in_vec_size; i += blocksize) {
708 x_local = *x;
709 scale = *scales;
710 bias = *biases;
711 w_local = *((device vec_w*)w);
712
713 qouter<U, tn * pack_factor, bits>(
714 (thread uint8_t*)&w_local, x_local, scale, bias, result);
715
716 x += blocksize;
717 scales += blocksize * out_vec_size_g;
718 biases += blocksize * out_vec_size_g;
719 w += blocksize * out_vec_size_w;
720 }
721 if (static_cast<int>(simd_lid) < remaining) {
722 x_local = *x;
723 scale = *scales;
724 bias = *biases;
725 w_local = *((device vec_w*)w);
726 } else {
727 x_local = 0;
728 scale = 0;
729 bias = 0;
730 }
731 qouter<U, tn * pack_factor, bits>(
732 (thread uint8_t*)&w_local, x_local, scale, bias, result);
733 }
734
735// Accumulate in the simdgroup
736#pragma clang loop unroll(full)
737 for (int k = 0; k < tn * pack_factor; k++) {
738 result[k] = simd_sum(result[k]);
739 }
740
741 // Store the result
742 if (simd_lid == 0) {
743#pragma clang loop unroll(full)
744 for (int k = 0; k < tn * pack_factor; k++) {
745 y[k] = static_cast<T>(result[k]);
746 }
747 }
748}
749
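qvm_impl computes the vector-matrix product x^T W for a weight matrix stored with the output dimension contiguous, so the scales and biases are grouped along the output axis and each lane contributes one x element to tn * pack_factor output columns via qouter before the per-column simd_sum. A scalar C++ reference with the same indexing convention (again an unpacked byte-per-value layout for clarity; qvm_reference is an illustrative helper, not MLX API) might look like:

// Scalar reference for qvm_impl (x^T W): the weight matrix is stored K x N, with
// quantization groups running along the N (output) axis, so scales/biases are
// indexed per input row k and per output group n / group_size.
#include <cstdint>
#include <vector>

std::vector<float> qvm_reference(
    const std::vector<uint8_t>& q,     // K x N quantized weights, row-major
    const std::vector<float>& scales,  // K x (N / group_size)
    const std::vector<float>& biases,  // K x (N / group_size)
    const std::vector<float>& x,       // K
    int K, int N, int group_size) {
  std::vector<float> y(N, 0.0f);
  int groups = N / group_size;
  for (int k = 0; k < K; k++) {
    for (int n = 0; n < N; n++) {
      float s = scales[k * groups + n / group_size];
      float b = biases[k * groups + n / group_size];
      y[n] += x[k] * (s * q[k * N + n] + b);
    }
  }
  return y;
}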
750template <
751 typename T,
752 const int group_size,
753 const int bits,
754 const bool aligned_N,
755 const int BM = 32,
756 const int BK = 32,
757 const int BN = 32>
758METAL_FUNC void qmm_t_impl(
759 const device uint32_t* w,
760 const device T* scales,
761 const device T* biases,
762 const device T* x,
763 device T* y,
764 threadgroup T* Xs,
765 threadgroup T* Ws,
766 const constant int& K,
767 const constant int& N,
768 const constant int& M,
769 uint3 tid [[threadgroup_position_in_grid]],
770 uint lid [[thread_index_in_threadgroup]],
771 uint simd_gid [[simdgroup_index_in_threadgroup]],
772 uint simd_lid [[thread_index_in_simdgroup]]) {
773 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
774 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
775
776 (void)lid;
777
778 constexpr int WM = 2;
779 constexpr int WN = 2;
780 constexpr int pack_factor = 32 / bits;
781 constexpr int BK_padded = (BK + 16 / sizeof(T));
782
783 // Instantiate the appropriate BlockMMA and Loader
784 using mma_t = mlx::steel::
785 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
786 using loader_x_t =
787 mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
788 using loader_w_t = QuantizedBlockLoader<
789 T,
790 BN,
791 BK,
792 BK_padded,
793 1,
794 WM * WN * SIMD_SIZE,
795 group_size,
796 bits>;
797
798 // Set the block
799 const int K_w = K / pack_factor;
800 const int K_g = K / group_size;
801 const int y_row = tid.y * BM;
802 const int y_col = tid.x * BN;
803
804 x += y_row * K;
805 w += y_col * K_w;
806 scales += y_col * K_g;
807 biases += y_col * K_g;
808 y += y_row * N + y_col;
809
810 // Make the x loader and mma operation
811 const short num_els = min(BM, M - y_row);
812 const short num_outs = min(BN, N - y_col);
813 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
814 loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
815 mma_t mma_op(simd_gid, simd_lid);
816
817 if (num_els < BM) {
818 if (!aligned_N && num_outs < BN) {
819 for (int k = 0; k < K; k += BK) {
820 threadgroup_barrier(mem_flags::mem_threadgroup);
821 loader_x.load_safe(short2(BK, num_els));
822 loader_w.load_safe(short2(BK, num_outs));
823 threadgroup_barrier(mem_flags::mem_threadgroup);
824 mma_op.mma(Xs, Ws);
825 loader_x.next();
826 loader_w.next();
827 }
828 } else {
829 for (int k = 0; k < K; k += BK) {
830 threadgroup_barrier(mem_flags::mem_threadgroup);
831 loader_x.load_safe(short2(BK, num_els));
832 loader_w.load_unsafe();
833 threadgroup_barrier(mem_flags::mem_threadgroup);
834 mma_op.mma(Xs, Ws);
835 loader_x.next();
836 loader_w.next();
837 }
838 }
839 } else {
840 if (!aligned_N && num_outs < BN) {
841 for (int k = 0; k < K; k += BK) {
842 threadgroup_barrier(mem_flags::mem_threadgroup);
843 loader_x.load_unsafe();
844 loader_w.load_safe(short2(BK, num_outs));
845 threadgroup_barrier(mem_flags::mem_threadgroup);
846 mma_op.mma(Xs, Ws);
847 loader_x.next();
848 loader_w.next();
849 }
850 } else {
851 for (int k = 0; k < K; k += BK) {
852 threadgroup_barrier(mem_flags::mem_threadgroup);
853 loader_x.load_unsafe();
854 loader_w.load_unsafe();
855 threadgroup_barrier(mem_flags::mem_threadgroup);
856 mma_op.mma(Xs, Ws);
857 loader_x.next();
858 loader_w.next();
859 }
860 }
861 }
862
863 // Store results to device memory
864 threadgroup_barrier(mem_flags::mem_threadgroup);
865 if (num_els < BM || num_outs < BN) {
866 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
867 } else {
868 mma_op.store_result(y, N);
869 }
870}
871
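Note the threadgroup tiles are sized with BK_padded = BK + 16 / sizeof(T) elements per row (and BN_padded likewise in qmm_n_impl below); the usual reason for padding rows of a shared tile like this is to stagger row starts across threadgroup-memory banks and avoid conflicts. A tiny sketch of what the padding works out to for a couple of element widths (padded_row is a hypothetical helper, and short stands in for a 2-byte half):

// Per-row padding added to the threadgroup tiles, expressed in elements.
#include <cstdio>

template <typename T>
constexpr int padded_row(int BK) { return BK + 16 / int(sizeof(T)); }

int main() {
  printf("2-byte elements: BK=32 -> BK_padded=%d\n", padded_row<short>(32));  // 40
  printf("4-byte elements: BK=32 -> BK_padded=%d\n", padded_row<float>(32));  // 36
}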
872template <
873 typename T,
874 const int group_size,
875 const int bits,
876 const int BM = 32,
877 const int BK = 32,
878 const int BN = 32>
879METAL_FUNC void qmm_n_impl(
880 const device uint32_t* w,
881 const device T* scales,
882 const device T* biases,
883 const device T* x,
884 device T* y,
885 threadgroup T* Xs,
886 threadgroup T* Ws,
887 const constant int& K,
888 const constant int& N,
889 const constant int& M,
890 uint3 tid [[threadgroup_position_in_grid]],
891 uint lid [[thread_index_in_threadgroup]],
892 uint simd_gid [[simdgroup_index_in_threadgroup]],
893 uint simd_lid [[thread_index_in_simdgroup]]) {
894 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
895 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
896
897 (void)lid;
898
899 constexpr int WM = 2;
900 constexpr int WN = 2;
901 constexpr int pack_factor = 32 / bits;
902 constexpr int BK_padded = (BK + 16 / sizeof(T));
903 constexpr int BN_padded = (BN + 16 / sizeof(T));
904
905 // Instantiate the appropriate BlockMMA and Loader
906 using mma_t = mlx::steel::
907 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
908 using loader_x_t = mlx::steel::
909 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
910 using loader_w_t = QuantizedBlockLoader<
911 T,
912 BK,
913 BN,
914 BN_padded,
915 0,
916 WM * WN * SIMD_SIZE,
917 group_size,
918 bits>;
919
920 // Set the block
921 const int y_row = tid.y * BM;
922 const int y_col = tid.x * BN;
923 x += y_row * K;
924 w += y_col / pack_factor;
925 scales += y_col / group_size;
926 biases += y_col / group_size;
927 y += y_row * N + y_col;
928
929 // Make the x loader and mma operation
930 const short num_els = min(BM, M - y_row);
931 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
932 loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
933 mma_t mma_op(simd_gid, simd_lid);
934
935 if (num_els < BM) {
936 if ((K % BK) != 0) {
937 const int k_blocks = K / BK;
938 for (int k = 0; k < k_blocks; k++) {
939 threadgroup_barrier(mem_flags::mem_threadgroup);
940 loader_x.load_safe(short2(BK, num_els));
941 loader_w.load_unsafe();
942 threadgroup_barrier(mem_flags::mem_threadgroup);
943 mma_op.mma(Xs, Ws);
944 loader_x.next();
945 loader_w.next();
946 }
947 const short num_k = K - k_blocks * BK;
948 threadgroup_barrier(mem_flags::mem_threadgroup);
949 loader_x.load_safe(short2(num_k, num_els));
950 loader_w.load_safe(short2(BN, num_k));
951 threadgroup_barrier(mem_flags::mem_threadgroup);
952 mma_op.mma(Xs, Ws);
953 } else {
954 for (int k = 0; k < K; k += BK) {
955 threadgroup_barrier(mem_flags::mem_threadgroup);
956 loader_x.load_safe(short2(BK, num_els));
957 loader_w.load_unsafe();
958 threadgroup_barrier(mem_flags::mem_threadgroup);
959 mma_op.mma(Xs, Ws);
960 loader_x.next();
961 loader_w.next();
962 }
963 }
964 } else {
965 if ((K % BK) != 0) {
966 const int k_blocks = K / BK;
967 for (int k = 0; k < k_blocks; k++) {
968 threadgroup_barrier(mem_flags::mem_threadgroup);
969 loader_x.load_unsafe();
970 loader_w.load_unsafe();
971 threadgroup_barrier(mem_flags::mem_threadgroup);
972 mma_op.mma(Xs, Ws);
973 loader_x.next();
974 loader_w.next();
975 }
976 const short num_k = K - k_blocks * BK;
977 threadgroup_barrier(mem_flags::mem_threadgroup);
978 loader_x.load_safe(short2(num_k, BM));
979 loader_w.load_safe(short2(BN, num_k));
980 threadgroup_barrier(mem_flags::mem_threadgroup);
981 mma_op.mma(Xs, Ws);
982 } else {
983 for (int k = 0; k < K; k += BK) {
984 threadgroup_barrier(mem_flags::mem_threadgroup);
985 loader_x.load_unsafe();
986 loader_w.load_unsafe();
987 threadgroup_barrier(mem_flags::mem_threadgroup);
988 mma_op.mma(Xs, Ws);
989 loader_x.next();
990 loader_w.next();
991 }
992 }
993 }
994
995 // Store results to device memory
996 threadgroup_barrier(mem_flags::mem_threadgroup);
997 if (num_els < BM) {
998 mma_op.store_result_safe(y, N, short2(BN, num_els));
999 } else {
1000 mma_op.store_result(y, N);
1001 }
1002}
1003
1004template <typename T>
1005METAL_FUNC void adjust_matrix_offsets(
1006 const device T*& x,
1007 const device uint32_t*& w,
1008 const device T*& scales,
1009 const device T*& biases,
1010 device T*& y,
1011 int output_stride,
1012 const constant int& x_batch_ndims,
1013 const constant int* x_shape,
1014 const constant size_t* x_strides,
1015 const constant int& w_batch_ndims,
1016 const constant int* w_shape,
1017 const constant size_t* w_strides,
1018 const constant size_t* s_strides,
1019 const constant size_t* b_strides,
1020 uint3 tid [[threadgroup_position_in_grid]]) {
1021 // Set the input/output matrices
1022 uint32_t x_idx = tid.z;
1023 uint32_t w_idx = tid.z;
1024 if (x_batch_ndims == 1) {
1025 x += x_idx * x_strides[0];
1026 } else {
1027 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1028 }
1029 if (w_batch_ndims == 1) {
1030 w += w_idx * w_strides[0];
1031 scales += w_idx * s_strides[0];
1032 biases += w_idx * b_strides[0];
1033 } else {
1034 ulong3 idx = elem_to_loc_broadcast(
1035 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1036 w += idx.x;
1037 scales += idx.y;
1038 biases += idx.z;
1039 }
1040 y += tid.z * output_stride;
1041}
1042
1043template <typename T>
1044METAL_FUNC void adjust_matrix_offsets(
1045 const device T*& x,
1046 const device uint32_t*& w,
1047 const device T*& scales,
1048 const device T*& biases,
1049 const device uint32_t* lhs_indices,
1050 const device uint32_t* rhs_indices,
1051 device T*& y,
1052 int output_stride,
1053 const constant int& batch_ndims,
1054 const constant int* batch_shape,
1055 const constant size_t* lhs_strides,
1056 const constant size_t* rhs_strides,
1057 const constant int& x_batch_ndims,
1058 const constant int* x_shape,
1059 const constant size_t* x_strides,
1060 const constant int& w_batch_ndims,
1061 const constant int* w_shape,
1062 const constant size_t* w_strides,
1063 const constant size_t* s_strides,
1064 const constant size_t* b_strides,
1065 uint3 tid [[threadgroup_position_in_grid]]) {
1066 // Set the input/output matrices
1067 uint32_t x_idx;
1068 uint32_t w_idx;
1069 if (batch_ndims == 1) {
1070 x_idx = lhs_indices[tid.z * lhs_strides[0]];
1071 w_idx = rhs_indices[tid.z * rhs_strides[0]];
1072 } else {
1073 ulong2 idx = elem_to_loc_broadcast(
1074 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
1075 x_idx = lhs_indices[idx.x];
1076 w_idx = rhs_indices[idx.y];
1077 }
1078 if (x_batch_ndims == 1) {
1079 x += x_idx * x_strides[0];
1080 } else {
1081 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
1082 }
1083 if (w_batch_ndims == 1) {
1084 w += w_idx * w_strides[0];
1085 scales += w_idx * s_strides[0];
1086 biases += w_idx * b_strides[0];
1087 } else {
1088 ulong3 idx = elem_to_loc_broadcast(
1089 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
1090 w += idx.x;
1091 scales += idx.y;
1092 biases += idx.z;
1093 }
1094 y += tid.z * output_stride;
1095}
1096
1097template <typename T, int group_size, int bits, int D, bool batched>
1098[[kernel]] void qmv_quad(
1099 const device uint32_t* w [[buffer(0)]],
1100 const device T* scales [[buffer(1)]],
1101 const device T* biases [[buffer(2)]],
1102 const device T* x [[buffer(3)]],
1103 device T* y [[buffer(4)]],
1104 const constant int& in_vec_size [[buffer(5)]],
1105 const constant int& out_vec_size [[buffer(6)]],
1106 const constant int& x_batch_ndims [[buffer(7)]],
1107 const constant int* x_shape [[buffer(8)]],
1108 const constant size_t* x_strides [[buffer(9)]],
1109 const constant int& w_batch_ndims [[buffer(10)]],
1110 const constant int* w_shape [[buffer(11)]],
1111 const constant size_t* w_strides [[buffer(12)]],
1112 const constant size_t* s_strides [[buffer(13)]],
1113 const constant size_t* b_strides [[buffer(14)]],
1114 uint3 tid [[threadgroup_position_in_grid]],
1115 uint quad_gid [[quadgroup_index_in_threadgroup]],
1116 uint quad_lid [[thread_index_in_quadgroup]]) {
1117 if (batched) {
1118 adjust_matrix_offsets<T>(
1119 x,
1120 w,
1121 scales,
1122 biases,
1123 y,
1124 out_vec_size,
1125 x_batch_ndims,
1126 x_shape,
1127 x_strides,
1128 w_batch_ndims,
1129 w_shape,
1130 w_strides,
1131 s_strides,
1132 b_strides,
1133 tid);
1134 }
1135 qmv_quad_impl<T, group_size, bits, D>(
1136 w,
1137 scales,
1138 biases,
1139 x,
1140 y,
1141 in_vec_size,
1142 out_vec_size,
1143 tid,
1144 quad_gid,
1145 quad_lid);
1146}
1147
1148template <typename T, int group_size, int bits, bool batched>
1149[[kernel]] void qmv_fast(
1150 const device uint32_t* w [[buffer(0)]],
1151 const device T* scales [[buffer(1)]],
1152 const device T* biases [[buffer(2)]],
1153 const device T* x [[buffer(3)]],
1154 device T* y [[buffer(4)]],
1155 const constant int& in_vec_size [[buffer(5)]],
1156 const constant int& out_vec_size [[buffer(6)]],
1157 const constant int& x_batch_ndims [[buffer(7)]],
1158 const constant int* x_shape [[buffer(8)]],
1159 const constant size_t* x_strides [[buffer(9)]],
1160 const constant int& w_batch_ndims [[buffer(10)]],
1161 const constant int* w_shape [[buffer(11)]],
1162 const constant size_t* w_strides [[buffer(12)]],
1163 const constant size_t* s_strides [[buffer(13)]],
1164 const constant size_t* b_strides [[buffer(14)]],
1165 uint3 tid [[threadgroup_position_in_grid]],
1166 uint simd_gid [[simdgroup_index_in_threadgroup]],
1167 uint simd_lid [[thread_index_in_simdgroup]]) {
1168 if (batched) {
1169 adjust_matrix_offsets<T>(
1170 x,
1171 w,
1172 scales,
1173 biases,
1174 y,
1175 out_vec_size,
1176 x_batch_ndims,
1177 x_shape,
1178 x_strides,
1179 w_batch_ndims,
1180 w_shape,
1181 w_strides,
1182 s_strides,
1183 b_strides,
1184 tid);
1185 }
1186 qmv_fast_impl<T, group_size, bits>(
1187 w,
1188 scales,
1189 biases,
1190 x,
1191 y,
1192 in_vec_size,
1193 out_vec_size,
1194 tid,
1195 simd_gid,
1196 simd_lid);
1197}
1198
1199template <typename T, const int group_size, const int bits, bool batched>
1200[[kernel]] void qmv(
1201 const device uint32_t* w [[buffer(0)]],
1202 const device T* scales [[buffer(1)]],
1203 const device T* biases [[buffer(2)]],
1204 const device T* x [[buffer(3)]],
1205 device T* y [[buffer(4)]],
1206 const constant int& in_vec_size [[buffer(5)]],
1207 const constant int& out_vec_size [[buffer(6)]],
1208 const constant int& x_batch_ndims [[buffer(7)]],
1209 const constant int* x_shape [[buffer(8)]],
1210 const constant size_t* x_strides [[buffer(9)]],
1211 const constant int& w_batch_ndims [[buffer(10)]],
1212 const constant int* w_shape [[buffer(11)]],
1213 const constant size_t* w_strides [[buffer(12)]],
1214 const constant size_t* s_strides [[buffer(13)]],
1215 const constant size_t* b_strides [[buffer(14)]],
1216 uint3 tid [[threadgroup_position_in_grid]],
1217 uint simd_gid [[simdgroup_index_in_threadgroup]],
1218 uint simd_lid [[thread_index_in_simdgroup]]) {
1219 if (batched) {
1220 adjust_matrix_offsets<T>(
1221 x,
1222 w,
1223 scales,
1224 biases,
1225 y,
1226 out_vec_size,
1227 x_batch_ndims,
1228 x_shape,
1229 x_strides,
1230 w_batch_ndims,
1231 w_shape,
1232 w_strides,
1233 s_strides,
1234 b_strides,
1235 tid);
1236 }
1237 qmv_impl<T, group_size, bits>(
1238 w,
1239 scales,
1240 biases,
1241 x,
1242 y,
1243 in_vec_size,
1244 out_vec_size,
1245 tid,
1246 simd_gid,
1247 simd_lid);
1248}
1249
1250template <typename T, const int group_size, const int bits, bool batched>
1251[[kernel]] void qvm(
1252 const device uint32_t* w [[buffer(0)]],
1253 const device T* scales [[buffer(1)]],
1254 const device T* biases [[buffer(2)]],
1255 const device T* x [[buffer(3)]],
1256 device T* y [[buffer(4)]],
1257 const constant int& in_vec_size [[buffer(5)]],
1258 const constant int& out_vec_size [[buffer(6)]],
1259 const constant int& x_batch_ndims [[buffer(7)]],
1260 const constant int* x_shape [[buffer(8)]],
1261 const constant size_t* x_strides [[buffer(9)]],
1262 const constant int& w_batch_ndims [[buffer(10)]],
1263 const constant int* w_shape [[buffer(11)]],
1264 const constant size_t* w_strides [[buffer(12)]],
1265 const constant size_t* s_strides [[buffer(13)]],
1266 const constant size_t* b_strides [[buffer(14)]],
1267 uint3 tid [[threadgroup_position_in_grid]],
1268 uint simd_gid [[simdgroup_index_in_threadgroup]],
1269 uint simd_lid [[thread_index_in_simdgroup]]) {
1270 if (batched) {
1271 adjust_matrix_offsets<T>(
1272 x,
1273 w,
1274 scales,
1275 biases,
1276 y,
1277 out_vec_size,
1278 x_batch_ndims,
1279 x_shape,
1280 x_strides,
1281 w_batch_ndims,
1282 w_shape,
1283 w_strides,
1284 s_strides,
1285 b_strides,
1286 tid);
1287 }
1288 qvm_impl<T, group_size, bits>(
1289 w,
1290 scales,
1291 biases,
1292 x,
1293 y,
1294 in_vec_size,
1295 out_vec_size,
1296 tid,
1297 simd_gid,
1298 simd_lid);
1299}
1300
1301template <
1302 typename T,
1303 const int group_size,
1304 const int bits,
1305 const bool aligned_N,
1306 const bool batched,
1307 const int BM = 32,
1308 const int BK = 32,
1309 const int BN = 32>
1310[[kernel]] void qmm_t(
1311 const device uint32_t* w [[buffer(0)]],
1312 const device T* scales [[buffer(1)]],
1313 const device T* biases [[buffer(2)]],
1314 const device T* x [[buffer(3)]],
1315 device T* y [[buffer(4)]],
1316 const constant int& K [[buffer(5)]],
1317 const constant int& N [[buffer(6)]],
1318 const constant int& M [[buffer(7)]],
1319 const constant int& x_batch_ndims [[buffer(8)]],
1320 const constant int* x_shape [[buffer(9)]],
1321 const constant size_t* x_strides [[buffer(10)]],
1322 const constant int& w_batch_ndims [[buffer(11)]],
1323 const constant int* w_shape [[buffer(12)]],
1324 const constant size_t* w_strides [[buffer(13)]],
1325 const constant size_t* s_strides [[buffer(14)]],
1326 const constant size_t* b_strides [[buffer(15)]],
1327 uint3 tid [[threadgroup_position_in_grid]],
1328 uint lid [[thread_index_in_threadgroup]],
1329 uint simd_gid [[simdgroup_index_in_threadgroup]],
1330 uint simd_lid [[thread_index_in_simdgroup]]) {
1331 (void)lid;
1332
1333 constexpr int BK_padded = (BK + 16 / sizeof(T));
1334
1335 threadgroup T Xs[BM * BK_padded];
1336 threadgroup T Ws[BN * BK_padded];
1337
1338 if (batched) {
1339 adjust_matrix_offsets<T>(
1340 x,
1341 w,
1342 scales,
1343 biases,
1344 y,
1345 M * N,
1346 x_batch_ndims,
1347 x_shape,
1348 x_strides,
1349 w_batch_ndims,
1350 w_shape,
1351 w_strides,
1352 s_strides,
1353 b_strides,
1354 tid);
1355 }
1356 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1357 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1358}
1359
1360template <
1361 typename T,
1362 const int group_size,
1363 const int bits,
1364 const bool batched,
1365 const int BM = 32,
1366 const int BK = 32,
1367 const int BN = 32>
1368[[kernel]] void qmm_n(
1369 const device uint32_t* w [[buffer(0)]],
1370 const device T* scales [[buffer(1)]],
1371 const device T* biases [[buffer(2)]],
1372 const device T* x [[buffer(3)]],
1373 device T* y [[buffer(4)]],
1374 const constant int& K [[buffer(5)]],
1375 const constant int& N [[buffer(6)]],
1376 const constant int& M [[buffer(7)]],
1377 const constant int& x_batch_ndims [[buffer(8)]],
1378 const constant int* x_shape [[buffer(9)]],
1379 const constant size_t* x_strides [[buffer(10)]],
1380 const constant int& w_batch_ndims [[buffer(11)]],
1381 const constant int* w_shape [[buffer(12)]],
1382 const constant size_t* w_strides [[buffer(13)]],
1383 const constant size_t* s_strides [[buffer(14)]],
1384 const constant size_t* b_strides [[buffer(15)]],
1385 uint3 tid [[threadgroup_position_in_grid]],
1386 uint lid [[thread_index_in_threadgroup]],
1387 uint simd_gid [[simdgroup_index_in_threadgroup]],
1388 uint simd_lid [[thread_index_in_simdgroup]]) {
1389 (void)lid;
1390
1391 constexpr int BK_padded = (BK + 16 / sizeof(T));
1392 constexpr int BN_padded = (BN + 16 / sizeof(T));
1393
1394 threadgroup T Xs[BM * BK_padded];
1395 threadgroup T Ws[BK * BN_padded];
1396
1397 if (batched) {
1398 adjust_matrix_offsets<T>(
1399 x,
1400 w,
1401 scales,
1402 biases,
1403 y,
1404 M * N,
1405 x_batch_ndims,
1406 x_shape,
1407 x_strides,
1408 w_batch_ndims,
1409 w_shape,
1410 w_strides,
1411 s_strides,
1412 b_strides,
1413 tid);
1414 }
1415
1416 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1417 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1418}
1419
1420template <typename T, int group_size, int bits>
1421[[kernel]] void bs_qmv_fast(
1422 const device uint32_t* w [[buffer(0)]],
1423 const device T* scales [[buffer(1)]],
1424 const device T* biases [[buffer(2)]],
1425 const device T* x [[buffer(3)]],
1426 device T* y [[buffer(4)]],
1427 const constant int& in_vec_size [[buffer(5)]],
1428 const constant int& out_vec_size [[buffer(6)]],
1429 const constant int& x_batch_ndims [[buffer(7)]],
1430 const constant int* x_shape [[buffer(8)]],
1431 const constant size_t* x_strides [[buffer(9)]],
1432 const constant int& w_batch_ndims [[buffer(10)]],
1433 const constant int* w_shape [[buffer(11)]],
1434 const constant size_t* w_strides [[buffer(12)]],
1435 const constant size_t* s_strides [[buffer(13)]],
1436 const constant size_t* b_strides [[buffer(14)]],
1437 const constant int& batch_ndims [[buffer(15)]],
1438 const constant int* batch_shape [[buffer(16)]],
1439 const device uint32_t* lhs_indices [[buffer(17)]],
1440 const device uint32_t* rhs_indices [[buffer(18)]],
1441 const constant size_t* lhs_strides [[buffer(19)]],
1442 const constant size_t* rhs_strides [[buffer(20)]],
1443 uint3 tid [[threadgroup_position_in_grid]],
1444 uint simd_gid [[simdgroup_index_in_threadgroup]],
1445 uint simd_lid [[thread_index_in_simdgroup]]) {
1446 adjust_matrix_offsets<T>(
1447 x,
1448 w,
1449 scales,
1450 biases,
1451 lhs_indices,
1452 rhs_indices,
1453 y,
1454 out_vec_size,
1455 batch_ndims,
1456 batch_shape,
1457 lhs_strides,
1458 rhs_strides,
1459 x_batch_ndims,
1460 x_shape,
1461 x_strides,
1462 w_batch_ndims,
1463 w_shape,
1464 w_strides,
1465 s_strides,
1466 b_strides,
1467 tid);
1468 qmv_fast_impl<T, group_size, bits>(
1469 w,
1470 scales,
1471 biases,
1472 x,
1473 y,
1474 in_vec_size,
1475 out_vec_size,
1476 tid,
1477 simd_gid,
1478 simd_lid);
1479}
1480
1481template <typename T, int group_size, int bits>
1482[[kernel]] void bs_qmv(
1483 const device uint32_t* w [[buffer(0)]],
1484 const device T* scales [[buffer(1)]],
1485 const device T* biases [[buffer(2)]],
1486 const device T* x [[buffer(3)]],
1487 device T* y [[buffer(4)]],
1488 const constant int& in_vec_size [[buffer(5)]],
1489 const constant int& out_vec_size [[buffer(6)]],
1490 const constant int& x_batch_ndims [[buffer(7)]],
1491 const constant int* x_shape [[buffer(8)]],
1492 const constant size_t* x_strides [[buffer(9)]],
1493 const constant int& w_batch_ndims [[buffer(10)]],
1494 const constant int* w_shape [[buffer(11)]],
1495 const constant size_t* w_strides [[buffer(12)]],
1496 const constant size_t* s_strides [[buffer(13)]],
1497 const constant size_t* b_strides [[buffer(14)]],
1498 const constant int& batch_ndims [[buffer(15)]],
1499 const constant int* batch_shape [[buffer(16)]],
1500 const device uint32_t* lhs_indices [[buffer(17)]],
1501 const device uint32_t* rhs_indices [[buffer(18)]],
1502 const constant size_t* lhs_strides [[buffer(19)]],
1503 const constant size_t* rhs_strides [[buffer(20)]],
1504 uint3 tid [[threadgroup_position_in_grid]],
1505 uint simd_gid [[simdgroup_index_in_threadgroup]],
1506 uint simd_lid [[thread_index_in_simdgroup]]) {
1507 adjust_matrix_offsets<T>(
1508 x,
1509 w,
1510 scales,
1511 biases,
1512 lhs_indices,
1513 rhs_indices,
1514 y,
1515 out_vec_size,
1516 batch_ndims,
1517 batch_shape,
1518 lhs_strides,
1519 rhs_strides,
1520 x_batch_ndims,
1521 x_shape,
1522 x_strides,
1523 w_batch_ndims,
1524 w_shape,
1525 w_strides,
1526 s_strides,
1527 b_strides,
1528 tid);
1529 qmv_impl<T, group_size, bits>(
1530 w,
1531 scales,
1532 biases,
1533 x,
1534 y,
1535 in_vec_size,
1536 out_vec_size,
1537 tid,
1538 simd_gid,
1539 simd_lid);
1540}
1541
1542template <typename T, int group_size, int bits>
1543[[kernel]] void bs_qvm(
1544 const device uint32_t* w [[buffer(0)]],
1545 const device T* scales [[buffer(1)]],
1546 const device T* biases [[buffer(2)]],
1547 const device T* x [[buffer(3)]],
1548 device T* y [[buffer(4)]],
1549 const constant int& in_vec_size [[buffer(5)]],
1550 const constant int& out_vec_size [[buffer(6)]],
1551 const constant int& x_batch_ndims [[buffer(7)]],
1552 const constant int* x_shape [[buffer(8)]],
1553 const constant size_t* x_strides [[buffer(9)]],
1554 const constant int& w_batch_ndims [[buffer(10)]],
1555 const constant int* w_shape [[buffer(11)]],
1556 const constant size_t* w_strides [[buffer(12)]],
1557 const constant size_t* s_strides [[buffer(13)]],
1558 const constant size_t* b_strides [[buffer(14)]],
1559 const constant int& batch_ndims [[buffer(15)]],
1560 const constant int* batch_shape [[buffer(16)]],
1561 const device uint32_t* lhs_indices [[buffer(17)]],
1562 const device uint32_t* rhs_indices [[buffer(18)]],
1563 const constant size_t* lhs_strides [[buffer(19)]],
1564 const constant size_t* rhs_strides [[buffer(20)]],
1565 uint3 tid [[threadgroup_position_in_grid]],
1566 uint simd_gid [[simdgroup_index_in_threadgroup]],
1567 uint simd_lid [[thread_index_in_simdgroup]]) {
1568 adjust_matrix_offsets<T>(
1569 x,
1570 w,
1571 scales,
1572 biases,
1573 lhs_indices,
1574 rhs_indices,
1575 y,
1576 out_vec_size,
1577 batch_ndims,
1578 batch_shape,
1579 lhs_strides,
1580 rhs_strides,
1581 x_batch_ndims,
1582 x_shape,
1583 x_strides,
1584 w_batch_ndims,
1585 w_shape,
1586 w_strides,
1587 s_strides,
1588 b_strides,
1589 tid);
1590 qvm_impl<T, group_size, bits>(
1591 w,
1592 scales,
1593 biases,
1594 x,
1595 y,
1596 in_vec_size,
1597 out_vec_size,
1598 tid,
1599 simd_gid,
1600 simd_lid);
1601}
1602
1603template <
1604 typename T,
1605 const int group_size,
1606 const int bits,
1607 const bool aligned_N,
1608 const int BM = 32,
1609 const int BK = 32,
1610 const int BN = 32>
1611[[kernel]] void bs_qmm_t(
1612 const device uint32_t* w [[buffer(0)]],
1613 const device T* scales [[buffer(1)]],
1614 const device T* biases [[buffer(2)]],
1615 const device T* x [[buffer(3)]],
1616 device T* y [[buffer(4)]],
1617 const constant int& K [[buffer(5)]],
1618 const constant int& N [[buffer(6)]],
1619 const constant int& M [[buffer(7)]],
1620 const constant int& x_batch_ndims [[buffer(8)]],
1621 const constant int* x_shape [[buffer(9)]],
1622 const constant size_t* x_strides [[buffer(10)]],
1623 const constant int& w_batch_ndims [[buffer(11)]],
1624 const constant int* w_shape [[buffer(12)]],
1625 const constant size_t* w_strides [[buffer(13)]],
1626 const constant size_t* s_strides [[buffer(14)]],
1627 const constant size_t* b_strides [[buffer(15)]],
1628 const constant int& batch_ndims [[buffer(16)]],
1629 const constant int* batch_shape [[buffer(17)]],
1630 const device uint32_t* lhs_indices [[buffer(18)]],
1631 const device uint32_t* rhs_indices [[buffer(19)]],
1632 const constant size_t* lhs_strides [[buffer(20)]],
1633 const constant size_t* rhs_strides [[buffer(21)]],
1634 uint3 tid [[threadgroup_position_in_grid]],
1635 uint lid [[thread_index_in_threadgroup]],
1636 uint simd_gid [[simdgroup_index_in_threadgroup]],
1637 uint simd_lid [[thread_index_in_simdgroup]]) {
1638 (void)lid;
1639
1640 constexpr int BK_padded = (BK + 16 / sizeof(T));
1641
1642 threadgroup T Xs[BM * BK_padded];
1643 threadgroup T Ws[BN * BK_padded];
1644
1645 adjust_matrix_offsets<T>(
1646 x,
1647 w,
1648 scales,
1649 biases,
1650 lhs_indices,
1651 rhs_indices,
1652 y,
1653 M * N,
1654 batch_ndims,
1655 batch_shape,
1656 lhs_strides,
1657 rhs_strides,
1658 x_batch_ndims,
1659 x_shape,
1660 x_strides,
1661 w_batch_ndims,
1662 w_shape,
1663 w_strides,
1664 s_strides,
1665 b_strides,
1666 tid);
1667 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1668 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1669}
1670
1671template <
1672 typename T,
1673 const int group_size,
1674 const int bits,
1675 const int BM = 32,
1676 const int BK = 32,
1677 const int BN = 32>
1678[[kernel]] void bs_qmm_n(
1679 const device uint32_t* w [[buffer(0)]],
1680 const device T* scales [[buffer(1)]],
1681 const device T* biases [[buffer(2)]],
1682 const device T* x [[buffer(3)]],
1683 device T* y [[buffer(4)]],
1684 const constant int& K [[buffer(5)]],
1685 const constant int& N [[buffer(6)]],
1686 const constant int& M [[buffer(7)]],
1687 const constant int& x_batch_ndims [[buffer(8)]],
1688 const constant int* x_shape [[buffer(9)]],
1689 const constant size_t* x_strides [[buffer(10)]],
1690 const constant int& w_batch_ndims [[buffer(11)]],
1691 const constant int* w_shape [[buffer(12)]],
1692 const constant size_t* w_strides [[buffer(13)]],
1693 const constant size_t* s_strides [[buffer(14)]],
1694 const constant size_t* b_strides [[buffer(15)]],
1695 const constant int& batch_ndims [[buffer(16)]],
1696 const constant int* batch_shape [[buffer(17)]],
1697 const device uint32_t* lhs_indices [[buffer(18)]],
1698 const device uint32_t* rhs_indices [[buffer(19)]],
1699 const constant size_t* lhs_strides [[buffer(20)]],
1700 const constant size_t* rhs_strides [[buffer(21)]],
1701 uint3 tid [[threadgroup_position_in_grid]],
1702 uint lid [[thread_index_in_threadgroup]],
1703 uint simd_gid [[simdgroup_index_in_threadgroup]],
1704 uint simd_lid [[thread_index_in_simdgroup]]) {
1705 (void)lid;
1706
1707 constexpr int BK_padded = (BK + 16 / sizeof(T));
1708 constexpr int BN_padded = (BN + 16 / sizeof(T));
1709
1710 threadgroup T Xs[BM * BK_padded];
1711 threadgroup T Ws[BK * BN_padded];
1712
1713 adjust_matrix_offsets<T>(
1714 x,
1715 w,
1716 scales,
1717 biases,
1718 lhs_indices,
1719 rhs_indices,
1720 y,
1721 M * N,
1722 batch_ndims,
1723 batch_shape,
1724 lhs_strides,
1725 rhs_strides,
1726 x_batch_ndims,
1727 x_shape,
1728 x_strides,
1729 w_batch_ndims,
1730 w_shape,
1731 w_strides,
1732 s_strides,
1733 b_strides,
1734 tid);
1735 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1736 w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
1737}
1738
1739template <typename T, const int group_size, const int bits>
1740[[kernel]] void affine_quantize(
1741 const device T* w [[buffer(0)]],
1742 device uint8_t* out [[buffer(1)]],
1743 device T* scales [[buffer(2)]],
1744 device T* biases [[buffer(3)]],
1745 uint2 index [[thread_position_in_grid]],
1746 uint2 grid_dim [[threads_per_grid]]) {
1747 constexpr T eps = T(1e-7);
1748 constexpr int simd_size = 32;
1749 constexpr int uint8_bits = 8;
1750 constexpr T n_bins = (1 << bits) - 1;
1751 constexpr int packs_per_int = uint8_bits / bits;
1752 constexpr int values_per_reduce = group_size / simd_size;
1753 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
1754 constexpr int writes_per_pack =
1755 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
1756
1757 static_assert(
1758 group_size % simd_size == 0,
1759 "Group size must be divisible by simd size.");
1760
1761 size_t offset = index.x + grid_dim.x * size_t(index.y);
1762 size_t in_index = offset * values_per_reduce;
1763 size_t out_index = offset * writes_per_pack;
1764
1765 T w_thread[values_per_reduce];
1766 T w_min = Limits<T>::max;
1767 T w_max = 0;
1768
1769#pragma clang loop unroll(full)
1770 for (int i = 0; i < values_per_reduce; i++) {
1771 T val = w[in_index + i];
1772 w_thread[i] = val;
1773 w_min = min(w_min, val);
1774 w_max = max(w_max, val);
1775 }
1776
1777 w_min = simd_min(w_min);
1778 w_max = simd_max(w_max);
1779
1780 T scale = max((w_max - w_min) / n_bins, eps);
1781 bool side = abs(w_min) > abs(w_max);
1782 scale = side ? scale : -scale;
1783 T edge = side ? w_min : w_max;
1784 T q0 = round(edge / scale);
1785 bool at_zero = q0 == 0.0f;
1786 scale = at_zero ? scale : edge / q0;
1787 T bias = at_zero ? T(0) : edge;
1788
1789 // Write out the scales and biases
1790 size_t gindex = in_index / group_size;
1791 if (in_index % group_size == 0) {
1792 scales[gindex] = scale;
1793 biases[gindex] = bias;
1794 }
1795
1796 uint8_t output = 0;
1797#pragma clang loop unroll(full)
1798 for (int i = 0; i < values_per_reduce; i++) {
1799 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
1800 if (bits == 8) {
1801 output = val;
1802 } else {
1803 output += val << (bits * (i % packs_per_int));
1804 }
1805
1806 if (packs_per_int < values_per_reduce &&
1807 i % packs_per_int == packs_per_int - 1) {
1808 out[out_index + i / packs_per_int] = output;
1809 output = 0;
1810 } else {
1811#pragma clang loop unroll(full)
1812 for (int j = 0; j < writes_per_reduce - 1; j++) {
1813 uint8_t sval = simd_shuffle_down(val, j + 1);
1814 output += sval << (bits * (values_per_reduce + j + i));
1815 }
1816 }
1817 }
1818 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
1819 out[out_index / writes_per_reduce] = output;
1820 }
1821}
1822
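The scale/bias selection above snaps the group's larger-magnitude endpoint onto an exact quantization level (the edge / q0 adjustment), which also lets an exact zero in the inputs survive the round trip when it falls in range. Below is a host-side C++ sketch of the same per-group formulas and the resulting reconstruction error, with made-up values, float standing in for T, and bits = 4 assumed.

// Host-side sketch of the per-group parameter selection in affine_quantize and the
// resulting quantize/dequantize round trip for one group of values.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int bits = 4;
  const float n_bins = (1 << bits) - 1;
  const float eps = 1e-7f;
  std::vector<float> w = {-0.8f, 0.1f, 0.35f, -0.2f, 0.6f, 0.0f, 0.9f, -0.4f};

  float w_min = w[0], w_max = 0.0f;  // the kernel seeds w_max with 0
  for (float v : w) { w_min = std::min(w_min, v); w_max = std::max(w_max, v); }

  // Same formulas as the kernel above.
  float scale = std::max((w_max - w_min) / n_bins, eps);
  bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = std::round(edge / scale);
  bool at_zero = (q0 == 0.0f);
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0.0f : edge;

  // Quantize and dequantize each value; report the worst reconstruction error.
  float max_err = 0.0f;
  for (float v : w) {
    float q = std::min(std::round((v - bias) / scale), n_bins);
    q = std::max(q, 0.0f);  // uint8_t storage cannot represent negative codes
    float rec = scale * q + bias;
    max_err = std::max(max_err, std::fabs(rec - v));
  }
  printf("scale=%f bias=%f max round-trip error=%f\n", scale, bias, max_err);
}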
1823template <typename T, const int group_size, const int bits>
1824[[kernel]] void affine_quantize_scales_biases(
1825 const device T* w [[buffer(0)]],
1826 const device T* scales [[buffer(1)]],
1827 const device T* biases [[buffer(2)]],
1828 device uint8_t* out [[buffer(3)]],
1829 uint2 index [[thread_position_in_grid]],
1830 uint2 grid_dim [[threads_per_grid]]) {
1831 constexpr int uint8_bits = 8;
1832 constexpr int packs_per_int = uint8_bits / bits;
1833 constexpr T n_bins = (1 << bits) - 1;
1834
1835 size_t offset = index.x + grid_dim.x * size_t(index.y);
1836 size_t in_index = offset * packs_per_int;
1837 size_t gindex = in_index / group_size;
1838
1839 T scale = scales[gindex];
1840 T bias = biases[gindex];
1841
1842 uint8_t output = 0;
1843#pragma clang loop unroll(full)
1844 for (int i = 0; i < packs_per_int; i++) {
1845 uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
1846 if (bits == 8) {
1847 output = val;
1848 } else {
1849 output += val << (bits * i);
1850 }
1851 }
1852 out[offset] = output;
1853}
1854
1855template <typename T, const int group_size, const int bits>
1856[[kernel]] void affine_dequantize(
1857 const device uint8_t* w [[buffer(0)]],
1858 const device T* scales [[buffer(1)]],
1859 const device T* biases [[buffer(2)]],
1860 device T* out [[buffer(3)]],
1861 uint2 index [[thread_position_in_grid]],
1862 uint2 grid_dim [[threads_per_grid]]) {
1863 constexpr int uint8_bits = 8;
1864 constexpr int packs_per_int = uint8_bits / bits;
1865
1866 size_t offset = index.x + grid_dim.x * size_t(index.y);
1867 size_t oindex = offset * packs_per_int;
1868 size_t gindex = oindex / group_size;
1869 T scale = scales[gindex];
1870 T bias = biases[gindex];
1871 uint val = w[offset];
1872
1873#pragma clang loop unroll(full)
1874 for (int i = 0; i < packs_per_int; i++) {
1875 uint8_t d;
1876 if (bits == 2) {
1877 d = (val >> (bits * i)) & 0x03;
1878 } else if (bits == 4) {
1879 d = (val >> (bits * i)) & 0x0f;
1880 } else if (bits == 8) {
1881 d = val;
1882 }
1883 out[oindex + i] = scale * d + bias;
1884 }
1885}
Definition quantized.h:1149
void qmv_quad(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:1098
static constant constexpr const int QUAD_SIZE
Definition quantized.h:11
U load_vector(const device T *x, thread U *x_thread)
Definition quantized.h:14
METAL_FUNC void qmv_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:498
U load_vector_safe(const device T *x, thread U *x_thread, int N)
Definition quantized.h:52
void bs_qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1611
U qdot(const device uint8_t *w, const thread U *x_thread, U scale, U bias, U sum)
Definition quantized.h:99
METAL_FUNC void qmv_fast_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:434
void qmm_t(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &K, const constant int &N, const constant int &M, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:1310
METAL_FUNC void adjust_matrix_offsets(const device T *&x, const device uint32_t *&w, const device T *&scales, const device T *&biases, device T *&y, int output_stride, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid)
Definition quantized.h:1005
void bs_qmv(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &batch_ndims, const constant int *batch_shape, const device uint32_t *lhs_indices, const device uint32_t *rhs_indices, const constant size_t *lhs_strides, const constant size_t *rhs_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1482
METAL_FUNC void qmv_quad_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, constant int &in_vec_size, const constant int &out_vec_size, uint3 tid, uint quad_gid, uint quad_lid)
Definition quantized.h:376
void qvm(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)
Definition quantized.h:1251
void qouter(const thread uint8_t *w, U x, U scale, U bias, thread U *result)
Definition quantized.h:187
void dequantize(const device uint8_t *w, U scale, U bias, threadgroup U *w_local)
Definition quantized.h:219
METAL_FUNC void qmm_t_impl(const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &K, const constant int &N, const constant int &M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)
Definition quantized.h:758
Definition utils.h:17
Definition quantized.h:262
const int group_stride
Definition quantized.h:282
static constant constexpr const short BCOLS_PACKED
Definition quantized.h:274
const device T * biases
Definition quantized.h:291
short group_step_cnt
Definition quantized.h:281
static constant constexpr const short group_steps
Definition quantized.h:277
const short thread_idx
Definition quantized.h:284
const device T * scales
Definition quantized.h:290
static constant constexpr const short n_reads
Definition quantized.h:275
void next()
Definition quantized.h:354
void load_safe(short2 src_tile_dim) const
Definition quantized.h:327
const int src_ld
Definition quantized.h:279
const short bi
Definition quantized.h:285
void load_unsafe() const
Definition quantized.h:314
static constant constexpr const short pack_factor
Definition quantized.h:273
threadgroup T * dst
Definition quantized.h:288
const int tile_stride
Definition quantized.h:280
const device uint32_t * src
Definition quantized.h:289
const short bj
Definition quantized.h:286
QuantizedBlockLoader(const device uint32_t *src_, const device T *scales_, const device T *biases_, const int src_ld_, threadgroup T *dst_, ushort simd_group_id, ushort simd_lane_id)
Definition quantized.h:293
Definition loader.h:25