#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
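// Affine quantization: each group of `group_size` weights along the reduction
// axis is stored as scale * q + bias, with the integer codes q packed
// little-endian into 32-bit words (pack_factor = 32 / bits codes per word).
// Helpers used below (Limits, elem_to_loc*, the steel BlockLoader/BlockMMA)
// are assumed to be provided by the sibling kernel headers (utils.h, steel/)
// that the build includes alongside this file.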
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
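// Note on the divisions above: x_thread[i + k] is pre-scaled by the place
// value of the k-th code in a packed byte so that qdot below can use a plain
// mask instead of mask-and-shift. E.g. for 4 bits, (w & 0x00f0) == q1 << 4,
// and the 1/16 folded into x_thread[4 * i + 1] cancels the << 4 exactly.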
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
    for (int i = N; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }
  }

  return sum;
}
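// Same as load_vector, but only the first N values are read from memory; the
// rest of x_thread is zeroed so a ragged tail contributes nothing downstream.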
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
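// qdot computes sum_i x_i * (scale * q_i + bias) for one packed row segment.
// The masked products accumulate scale-free (the place values cancel against
// the pre-scaled x_thread), so the result folds down to
// scale * accum + bias * sum, where sum = sum_i x_i from load_vector.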
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
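// qdot_safe mirrors qdot but only touches the first N values, matching a
// vector loaded with load_vector_safe.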
template <typename U, int values_per_thread, int bits>
inline void qouter(
    const thread uint8_t* w,
    U x,
    U scale,
    U bias,
    thread U* result) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
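// qouter is the rank-1 counterpart of qdot, used by the qvm path: it
// dequantizes one packed row of w on the fly and accumulates x * w into
// `result`, i.e. one row's contribution to the outer product x * W.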
template <typename U, int N, int bits>
inline void dequantize(
    const device uint8_t* w,
    U scale,
    U bias,
    threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 4 || bits == 8,
      "Template undefined for bits not in {2, 4, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
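// dequantize expands N packed codes into threadgroup memory. The s[] table
// plays the same trick as load_vector: s[k] = scale / place_value(k), so a
// masked (unshifted) code times s[k] already equals scale * q.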
264 "The group size should be larger than the columns");
266 group_size % BCOLS == 0,
267 "The group size should be divisible by the columns");
269 bits == 2 || bits == 4 || bits == 8,
270 "Template undefined for bits not in {2, 4, 8}");
288 const device uint32_t*
src;
293 const device uint32_t* src_,
294 const device T* scales_,
295 const device T* biases_,
298 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
299 ushort simd_lane_id [[thread_index_in_simdgroup]])
305 thread_idx(simd_group_id * 32 + simd_lane_id),
320 for (
int i = 0; i <
n_reads; i++) {
331 if (reduction_dim == 1 &&
bi >= src_tile_dim.y) {
338 if (reduction_dim == 0 &&
bi >= src_tile_dim.x) {
347 for (
int i = 0; i <
n_reads; i++) {
355 if (reduction_dim == 1) {
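// QuantizedBlockLoader mirrors the steel BlockLoader but dequantizes on the
// fly while staging a BROWS x BCOLS tile into threadgroup memory. Each thread
// expands n_reads packed 32-bit words per tile; next() advances one tile
// along the reduction dim and bumps the scales/biases pointers whenever a
// quantization group boundary is crossed (every group_steps tiles).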
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = bits > 2 ? 2 : 1;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    w += block_size / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
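// Shape math for qmv_fast, e.g. with bits == 4 and group_size == 64:
// pack_factor = 8 and packs_per_thread = 2, so each thread consumes 16 input
// values and a simdgroup covers block_size = 16 * 32 = 512 inputs per
// iteration. Two simdgroups each produce 4 rows, so one threadgroup yields 8
// outputs. The "fast" variant assumes the vector sizes divide these tiles.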
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row =
      min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even one full tile in the matrix.
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values.
  else {
    w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        const device uint8_t* wl =
            (const device uint8_t*)(w + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      w += block_size / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl =
          (const device uint8_t*)(w + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
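// qmv_impl handles the ragged edges qmv_fast cannot: rows past out_vec_size
// exit early, a short final block is loaded with the *_safe helpers, and when
// a full 8-row tile does not fit, used_out_row slides the tile back so its
// rows stay in range (recomputing a few outputs instead of reading out of
// bounds).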
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int tn = 32 / pack_factor;
  constexpr int blocksize = SIMD_SIZE;

  typedef float U;
  typedef struct {
    uint32_t wi[tn];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col =
      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
  w += out_col / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.y * in_vec_size + simd_lid;
  y += tid.y * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of blocksize
  int remaining = in_vec_size % blocksize;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
  } else {
    for (int i = blocksize; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

// Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
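// qvm computes y = x * W by accumulating scaled rows of W: each lane owns one
// input element per 32-wide block and a tn * pack_factor slice of the output
// row, and the per-k simd_sum reduces the partial outer products across the
// simdgroup.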
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader, w loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
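// qmm_t computes Y = X * W^T as a standard tiled GEMM: stage a BM x BK tile
// of X and a BN x BK tile of dequantized W into threadgroup memory, run
// BlockMMA, repeat over K. The four branches only pick safe vs. unsafe loads
// for the M and N edges; aligned_N promises the caller that N % BN == 0, so
// the N-edge checks compile away.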
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device T* x,
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  w += y_col / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader, w loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
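// qmm_n (Y = X * W) differs from qmm_t in that the quantized tile is BK x BN
// and traversed along dim 0, so K-edges need explicit handling: when
// K % BK != 0 the main loop stops one tile early and a final load_safe pass
// covers the num_k remainder.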
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant size_t* lhs_strides,
    const constant size_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
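// adjust_matrix_offsets resolves one (lhs, rhs) pair of a batched/gathered
// matmul: tid.z picks the batch element, lhs_indices/rhs_indices select which
// x and which (w, scales, biases) to use, and the pointers are advanced in
// place so the bs_* kernels below can run the regular single-matrix impls.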
template <typename T, int group_size, int bits>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  qvm_impl<T, group_size, bits>(
      x, w, scales, biases, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& M [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
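// bs_qmv_fast above and the remaining bs_* kernels all follow the same
// pattern: adjust_matrix_offsets to select the batch element, then dispatch
// the matching single-matrix impl.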
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& batch_ndims [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* lhs_strides [[buffer(11)]],
    const constant size_t* rhs_strides [[buffer(12)]],
    const constant int& x_batch_ndims [[buffer(13)]],
    const constant int* x_shape [[buffer(14)]],
    const constant size_t* x_strides [[buffer(15)]],
    const constant int& w_batch_ndims [[buffer(16)]],
    const constant int* w_shape [[buffer(17)]],
    const constant size_t* w_strides [[buffer(18)]],
    const constant size_t* s_strides [[buffer(19)]],
    const constant size_t* b_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides, x_batch_ndims,
      x_shape, x_strides, w_batch_ndims, w_shape, w_strides, s_strides,
      b_strides, tid);
  qvm_impl<T, group_size, bits>(
      x, w, scales, biases, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& K [[buffer(9)]],
    const constant int& batch_ndims [[buffer(10)]],
    const constant int* batch_shape [[buffer(11)]],
    const constant size_t* lhs_strides [[buffer(12)]],
    const constant size_t* rhs_strides [[buffer(13)]],
    const constant int& x_batch_ndims [[buffer(14)]],
    const constant int* x_shape [[buffer(15)]],
    const constant size_t* x_strides [[buffer(16)]],
    const constant int& w_batch_ndims [[buffer(17)]],
    const constant int* w_shape [[buffer(18)]],
    const constant size_t* w_strides [[buffer(19)]],
    const constant size_t* s_strides [[buffer(20)]],
    const constant size_t* b_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& K [[buffer(9)]],
    const constant int& batch_ndims [[buffer(10)]],
    const constant int* batch_shape [[buffer(11)]],
    const constant size_t* lhs_strides [[buffer(12)]],
    const constant size_t* rhs_strides [[buffer(13)]],
    const constant int& x_batch_ndims [[buffer(14)]],
    const constant int* x_shape [[buffer(15)]],
    const constant size_t* x_strides [[buffer(16)]],
    const constant int& w_batch_ndims [[buffer(17)]],
    const constant int* w_shape [[buffer(18)]],
    const constant size_t* w_strides [[buffer(19)]],
    const constant size_t* s_strides [[buffer(20)]],
    const constant size_t* b_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N, batch_ndims,
      batch_shape, lhs_strides, rhs_strides, x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr int uint8_bits = 8;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = offset * writes_per_pack;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 0; j < writes_per_reduce - 1; j++) {
        uint8_t sval = simd_shuffle_down(val, j + 1);
        output += sval << (bits * (values_per_reduce + j + i));
      }
    }
  }
  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
    out[out_index / writes_per_reduce] = output;
  }
}
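// Scale/bias selection above: the raw step is (w_max - w_min) / n_bins;
// `side` picks the extreme with the larger magnitude as `edge`, and snapping
// scale = edge / round(edge / scale) makes the real value 0 land exactly on
// an integer code (q = -q0), so zero weights survive the round trip. With
// bias set to `edge`, codes count upward from 0 at that edge. E.g.
// group_size == 64 with simd_size == 32 gives values_per_reduce == 2, so one
// simdgroup reduces exactly one quantization group.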
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
    const device T* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device uint8_t* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;
  constexpr T n_bins = (1 << bits) - 1;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * packs_per_int;
  size_t gindex = in_index / group_size;

  T scale = scales[gindex];
  T bias = biases[gindex];

  uint8_t output = 0;
#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
    output += val << (bits * i);
  }
  out[offset] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int uint8_bits = 8;
  constexpr int packs_per_int = uint8_bits / bits;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];
  uint val = w[offset];

#pragma clang loop unroll(full)
  for (int i = 0; i < packs_per_int; i++) {
    uint8_t d;
    if (bits == 2) {
      d = (val >> (bits * i)) & 0x03;
    } else if (bits == 4) {
      d = (val >> (bits * i)) & 0x0f;
    } else if (bits == 8) {
      d = val;
    }
    out[oindex + i] = scale * d + bias;
  }
}