#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }
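
  // The divisions by 4, 16, and 64 (and by 16, 256, and 4096 in the 4-bit
  // path below) pre-scale x by the place value of each packed element so that
  // qdot() can multiply by the masked-but-unshifted weight bits.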

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
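
// Variant of load_vector for ragged tails: it loads only the first N values
// and zero-fills x_thread up to values_per_thread.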

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  return sum;
}

template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }
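
  // Affine dequantization distributes over the dot product:
  //   dot(x, scale * q + bias) == scale * dot(x, q) + bias * sum(x),
  // so the precomputed sum(x) folds the bias in with two multiplies.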
  return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
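
// dequantize() below uses the same trick as qouter: the per-position scales
// s[k] absorb the shift (scale / 4^k for 2 bits, scale / 16^k for 4 bits), so
// packed fields are only masked, never shifted.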

template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}

template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  MLX_MTL_CONST short pack_factor = 32 / bits;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;

  threadgroup T* dst;
  const device uint32_t* src;
  const device T* scales;
  const device T* biases;

  QuantizedBlockLoader(
      const device uint32_t* src_,
      const device T* scales_,
      const device T* biases_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld / pack_factor + bj),
        scales(scales_ + bi * src_ld / group_size),
        biases(biases_ + bi * src_ld / group_size) {}

  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales++;
          biases++;
        }
      } else {
        scales++;
        biases++;
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};
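
// The loader dequantizes one BROWS x BCOLS tile of packed weights into
// threadgroup memory per iteration; next() advances the packed source by one
// tile and steps the scale/bias pointers only once every group_steps tiles,
// since one quantization group spans several column tiles. A qmm_t-style
// instantiation (mirroring the loader_w_t alias used below) would look like:
//
//   using loader_w_t = QuantizedBlockLoader<
//       T, BN, BK, BK_padded, 1, WM * WN * SIMD_SIZE, group_size, bits>;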

template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    const device uint8_t* wl =
        (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
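
// Each quadgroup of QUAD_SIZE threads cooperates on one dot product and the
// row loop lets it produce up to results_per_quadgroup outputs spaced
// quads_per_simd rows apart; quad_sum() reduces the partials within the quad.
// This kernel only handles input vectors of compile-time length D.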

template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = bits > 2 ? 2 : 1;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    w += block_size / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
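
// qmv_fast_impl assumes in_vec_size divides evenly into block_size chunks and
// that all results_per_simdgroup output rows exist, so it needs no bounds
// checks; qmv_impl below adds the guards for ragged sizes.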

template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even one full tile of output rows.
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values.
  else {
    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
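
// The two branches above differ only in how they guard: with fewer output
// rows than one tile, every row access is bounds-checked; otherwise the last
// simdgroup is shifted back (used_out_row) so it recomputes a few rows rather
// than read out of bounds.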

template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int tn = 32 / pack_factor;
  constexpr int blocksize = SIMD_SIZE;

  typedef float U;
  typedef struct {
    uint32_t wi[tn];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col =
      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
  w += out_col / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.y * in_vec_size + simd_lid;
  y += tid.y * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of blocksize
  int remaining = in_vec_size % blocksize;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
  } else {
    for (int i = blocksize; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

// Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
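
// qvm_impl computes x^T * W with W packed along the output dimension: each
// lane loads one x value plus a vec_w of tn packed words, and qouter()
// accumulates the scaled outer product into tn * pack_factor running outputs
// that are finally reduced across the simdgroup.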

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
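
// qmm_t_impl tiles over K: both loaders fill threadgroup tiles (Ws is
// dequantized on the fly by QuantizedBlockLoader) and the steel BlockMMA
// accumulates Xs * Ws^T. aligned_N selects the branches without column guards.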

template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
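
// qmm_n_impl is the non-transposed variant: the weight tile is BK x BN with
// reduction_dim = 0, and a ragged K is handled by peeling off the final
// partial K-block instead of branching inside the main loop.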

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}

template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant size_t* lhs_strides,
    const constant size_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
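
// The second overload resolves gather-style batching: lhs_indices and
// rhs_indices pick which x and which w/scales/biases batch element the
// threadgroup at tid.z operates on.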

template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid,
      quad_lid);
}
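
// The remaining entry points all follow the same pattern: optionally adjust
// the batch offsets, then dispatch to the corresponding *_impl.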

template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr int uint8_bits = 8;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = offset * writes_per_pack;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 0; j < writes_per_reduce - 1; j++) {
        uint8_t sval = simd_shuffle_down(val, j + 1);
        output += sval << (bits * (values_per_reduce + j + i));
      }
    }
  }
  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
    out[out_index / writes_per_reduce] = output;
  }
}
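
// The scale/bias selection above snaps the group's larger-magnitude edge onto
// an exact quantization level: with q0 = round(edge / scale), re-deriving
// scale = edge / q0 makes edge quantize with zero error, and bias = edge keeps
// the representable range anchored at that endpoint.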

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
    const device T* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device uint8_t* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr T n_bins = (1 << bits) - 1;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * packs_per_int;
  size_t gindex = in_index / group_size;

  T scale = scales[gindex];
  T bias = biases[gindex];

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * i);
    }
  }
  out[offset] = output;
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];
  uint val = w[offset];

#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t d;
    if (bits == 2) {
      d = (val >> (bits * i)) & 0x03;
    } else if (bits == 4) {
      d = (val >> (bits * i)) & 0x0f;
    } else if (bits == 8) {
      d = val;
    }
    out[oindex + i] = scale * d + bias;
  }
}
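
// Worked example of the packing arithmetic: for bits = 4, pack_factor is
// 32 / 4 = 8 values per uint32, so a row of 4096 weights occupies 512 packed
// words plus 4096 / group_size scale/bias pairs (64 pairs at group_size = 64).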
Definition quantized.h:293