#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
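// Note on the divisions above: qdot() below masks packed values in place
// (w & 0x0c, w & 0x30, ...) without shifting them down, so lane k of a pack
// is read as 2^(bits * k) times its true value. Pre-dividing the activations
// by 4/16/64 (2-bit) or 16/256/4096 (4-bit) cancels exactly that factor,
// trading bit shifts in the hot dot-product loop for a few multiplies at
// load time.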
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  return sum;
}
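// Same contract as load_vector(), but only the first N activations are
// valid; the tail of x_thread is zeroed so the packed dot products can run
// over the full values_per_thread without further bounds checks.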
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
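// Worked example for bits == 4: the packed uint16 0x21BA holds the values
// {10, 11, 1, 2} (low nibble first). Masking yields 10, 11*16, 1*256 and
// 2*4096; because load_vector() stored x1/16, x2/256 and x3/4096, each
// product collapses back to x_k * v_k. The affine part is applied once per
// call:
//   y = scale * sum_k(x_k * v_k) + bias * sum_k(x_k),
// which is why the precomputed activation sum is threaded through.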
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline void qouter(
    const thread uint8_t* w,
    U x,
    U scale,
    U bias,
    thread U* result) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
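// qouter() cancels the unshifted masks on the weight side instead: the
// per-lane scales s[k] = scale / 2^(bits * k) fold the implicit shift of
// lane k into its dequantization scale, so each output accumulates
// x * (scale * v_k + bias) with no explicit bit shifts.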
template <typename U, int N, int bits>
inline void dequantize(
    const device uint8_t* w,
    U scale,
    U bias,
    threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
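// dequantize() is the threadgroup-memory counterpart of the routines above:
// it expands one run of packed weights into w_local so the block matmul
// kernels can consume dequantized tiles with the same layout as the
// unquantized path.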
265 "The group size should be larger than the columns");
267 group_size % BCOLS == 0,
268 "The group size should be divisible by the columns");
270 bits == 2 || bits == 4 || bits == 8,
271 "Template undefined for bits not in {2, 4, 8}");
289 const device uint32_t*
src;
294 const device uint32_t* src_,
295 const device T* scales_,
296 const device T* biases_,
299 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
300 ushort simd_lane_id [[thread_index_in_simdgroup]])
306 thread_idx(simd_group_id * 32 + simd_lane_id),
321 for (
int i = 0; i <
n_reads; i++) {
332 if (reduction_dim == 1 &&
bi >= src_tile_dim.y) {
339 if (reduction_dim == 0 &&
bi >= src_tile_dim.x) {
348 for (
int i = 0; i <
n_reads; i++) {
356 if (reduction_dim == 1) {
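// The loader reads BROWS x BCOLS tiles of packed weights, dequantizing into
// threadgroup memory as it goes. Because BCOLS never exceeds group_size
// (enforced by the static_asserts above), a tile row touches exactly one
// scale/bias pair; next() advances scales and biases only every group_steps
// tiles when marching along the reduction dimension.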
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    const device uint8_t* wl =
        (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
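// One quadgroup (QUAD_SIZE == 4 threads) cooperates on
// results_per_quadgroup output rows, each thread covering D / QUAD_SIZE
// consecutive activations; quad_sum() then reduces the four partials. This
// variant presumably pays off when the input vector is short enough that a
// full 32-wide simdgroup per row would leave most lanes idle.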
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = bits > 2 ? 2 : 1;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    w += block_size / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
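// The "fast" path assumes a well aligned problem: every simdgroup owns
// results_per_simdgroup full output rows and in_vec_size is a multiple of
// block_size, so no load or store needs bounds checks. qmv_impl() below is
// the guarded fallback for ragged shapes.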
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even one full tile in the matrix.
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum =
        load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values.
  else {
    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum =
        load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
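// Edge handling: rather than masking every access, the last simdgroup of
// rows is shifted back to used_out_row = min(out_vec_size -
// results_per_simdgroup, out_row), recomputing a few already-covered rows so
// all loads stay in bounds; the fully guarded branch is only taken when the
// matrix has fewer rows than a single tile.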
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int tn = 32 / pack_factor;
  constexpr int blocksize = SIMD_SIZE;

  typedef float U;
  typedef struct {
    uint32_t wi[tn];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col =
      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
  w += out_col / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.y * in_vec_size + simd_lid;
  y += tid.y * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of blocksize
  int remaining = in_vec_size % blocksize;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
  } else {
    for (int i = blocksize; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
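// qvm computes x * W with W packed along the output dimension: each lane
// reads one activation per 32-row block and a contiguous run of tn packed
// uint32 columns, accumulating rank-1 updates via qouter(); the simdgroup
// then reduces the 32 partial output rows with simd_sum().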
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader, the w loader and the mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
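// The four branches above peel the bounds checks out of the K loop: row
// (num_els) and column (num_outs) raggedness each pick safe or unsafe loads
// once, so the common fully-tiled case runs with no per-iteration guards.
// aligned_N lets callers promise N % BN == 0 at compile time.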
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader, the w loader and the mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
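// Unlike qmm_t_impl, the weight loader here walks the K dimension
// (reduction_dim == 0), so a ragged K needs an explicit remainder tile
// (num_k) instead of a safe column load. Raggedness in M is handled by
// store_result_safe; N is effectively assumed to be a multiple of BN.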
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant size_t* lhs_strides,
    const constant size_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
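// Two flavors of batch setup: the first overload walks x and w through
// their own (possibly broadcast) batch strides directly from tid.z; this
// second one first gathers explicit batch indices from lhs_indices /
// rhs_indices, which is what the index-gathering bs_* kernels below use.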
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid,
      quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size, x_batch_ndims, x_shape,
        x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
        tid);
  }
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, y, out_vec_size, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);

  // When in_vec_size is not divisible by split_k, the final block along the
  // reduction axis is shorter.
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid,
      simd_lid);
}
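// Split-k: the reduction dimension is divided across tid.z so each slice
// writes a partial output; final_block_size trims the last slice when
// in_vec_size does not divide evenly. The partial results are expected to
// be summed in a separate pass.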
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N, x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape,
      x_strides, w_batch_ndims, w_shape, w_strides, s_strides, b_strides,
      tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
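// All bs_* kernels share the same shape: gather the batch element via
// lhs_indices / rhs_indices, then defer to the corresponding dense
// implementation (qmv_fast_impl, qmv_impl, qvm_impl, qmm_t_impl,
// qmm_n_impl) exactly as the non-gathering kernels above do.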
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr int uint8_bits = 8;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = offset * writes_per_pack;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 0; j < writes_per_reduce - 1; j++) {
        uint8_t sval = simd_shuffle_down(val, j + 1);
        output += sval << (bits * (values_per_reduce + j + i));
      }
    }
  }
  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
    out[out_index / writes_per_reduce] = output;
  }
}
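// Quantization math, concretely (bits == 4, so n_bins == 15): a group with
// w_min == -0.8 and w_max == 0.4 starts from scale == 1.2 / 15 == 0.08.
// Since |w_min| > |w_max|, edge == -0.8 and q0 == round(-0.8 / 0.08) == -10,
// so the scale is snapped to edge / q0 == 0.08 with bias == -0.8, making the
// edge exactly representable: q == round((w + 0.8) / 0.08) maps w == -0.8 to
// 0 and w == 0.4 to 15.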
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
    const device T* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device uint8_t* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr T n_bins = (1 << bits) - 1;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * packs_per_int;
  size_t gindex = in_index / group_size;

  T scale = scales[gindex];
  T bias = biases[gindex];

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * i);
    }
  }
  out[offset] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];
  uint val = w[offset];

#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t d;
    if (bits == 2) {
      d = (val >> (bits * i)) & 0x03;
    } else if (bits == 4) {
      d = (val >> (bits * i)) & 0x0f;
    } else if (bits == 8) {
      d = val;
    }
    out[oindex + i] = scale * d + bias;
  }
}
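// Round trip: with q == round((w - bias) / scale) from affine_quantize, the
// reconstruction scale * q + bias recovers each in-range weight to within
// |scale| / 2, and the chosen edge (the extremum with the larger magnitude)
// is exact by construction of the snapped scale above.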