#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];

      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
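// Note on the pre-scaling above: qdot() below multiplies the activations
// against masked-but-unshifted packed weights. A value stored in the second
// 2-bit slot of a byte comes out of (w[i] & 0x0c) already multiplied by 4, so
// dividing x[i + 1] by 4 here cancels that factor: if the slot holds q = 2,
// then (w[i] & 0x0c) == 8 and (x / 4) * 8 == x * 2, the intended product. The
// same reasoning produces the divisors 16/64 for the remaining 2-bit slots and
// the per-slot divisors used for 3, 4, and 6 bits. The running `sum` of raw
// activations is returned so the bias term can be folded in later as a single
// multiply (see qdot).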
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];

      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  for (int i = N; i < values_per_thread; i++) {
    x_thread[i] = 0;
  }

  return sum;
}
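// load_vector_safe is the tail-block variant of load_vector: it applies the
// same per-slot pre-scaling to the first N activations and zero-fills the
// rest of x_thread, so a full-width dot product over values_per_thread
// entries never reads past the end of the input vector.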
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
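// The return statement folds the affine dequantization into the dot product:
//   sum_i x_i * (scale * q_i + bias)
//     = scale * sum_i (x_i * q_i) + bias * sum_i x_i.
// `accum` holds sum_i (x_i * q_i), computed against the pre-scaled activations
// from load_vector so the masked weights contribute the correct power of two,
// and `sum` is sum_i x_i computed once while loading. One fused multiply-add
// per row therefore recovers the exact affine result.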
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
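// qouter is the transposed counterpart of qdot, used by the qvm path below:
// instead of reducing a row of packed weights against a vector it scatters a
// single activation x into values_per_thread partial outputs, dequantizing
// each packed element as (q * scale + bias) on the fly. For the power-of-two
// widths it reuses the divided-scale trick (s = {scale, scale/4, scale/16,
// scale/64} for 2 bits, {scale, scale/16} for 4 bits) so no shifts are needed.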
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      w_local += 4 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
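// Packing layouts assumed by dequantize (and by qdot/qouter above):
//   2/4/8 bits: values are packed little-endian within a byte, so one byte
//     holds 4, 2, or 1 values respectively.
//   3 bits: 8 consecutive values are packed into 3 bytes; values that straddle
//     a byte boundary are reassembled from the top bits of one byte and the
//     bottom bits of the next, e.g. q2 = ((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2).
//   6 bits: 4 consecutive values are packed into 3 bytes with the same
//     straddling scheme, e.g. q1 = ((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2).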
446 "The group size should be larger than the columns");
448 group_size % BCOLS == 0,
449 "The group size should be divisible by the columns");
451 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
452 "Template undefined for bits not in {2, 3, 4, 6, 8}");
471 const device uint8_t*
src;
476 const device uint8_t* src_,
477 const device T* scales_,
478 const device T* biases_,
481 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
482 ushort simd_lane_id [[thread_index_in_simdgroup]])
489 thread_idx(simd_group_id * 32 + simd_lane_id),
505 for (
int i = 0; i <
n_reads; i++) {
516 if (reduction_dim == 1 &&
bi >= src_tile_dim.y) {
523 if (reduction_dim == 0 &&
bi >= src_tile_dim.x) {
532 for (
int i = 0; i <
n_reads; i++) {
543 if (reduction_dim == 1) {
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
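// qmv_quad_impl maps one quadgroup (QUAD_SIZE = 4 threads) onto a short input
// vector of length D: each thread pre-scales D / 4 activations once and reuses
// them against results_per_quadgroup rows of packed weights, and the partial
// dot products are combined with quad_sum before lane 0 of the quad writes the
// outputs. This is the small-D specialization; the simdgroup variants below
// block over k to handle longer input vectors.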
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
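// qmv_fast_impl assumes the weight matrix is large enough that no bounds
// checks are needed: two simdgroups per threadgroup each own four output rows,
// every lane loads values_per_thread activations per k-block, and the per-row
// partials are reduced with simd_sum at the end. qmv_impl below is the guarded
// version of the same loop structure for ragged row/column counts.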
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even one full tile in the matrix.
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    ws +=
        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back so that it recomputes some
  // already-covered output rows instead of reading out of bounds.
  else {
    ws += used_out_row * in_vec_size_w +
        simd_lid * packs_per_thread * bytes_per_pack;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
  constexpr int tn = 32 / pack_factor;
  constexpr int block_size = SIMD_SIZE;

  using W_T =
      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
  const device W_T* ws = (const device W_T*)w;

  typedef float U;
  typedef struct {
    W_T wi[tn * bytes_per_pack];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.x * in_vec_size + simd_lid;
  y += tid.x * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of block_size
  int remaining = in_vec_size % block_size;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
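// qvm_impl computes x^T * W with W stored packed along the output dimension:
// within each 32-row block (block_size == SIMD_SIZE) every lane owns one input
// element and a tn * pack_factor (= 32) column strip, accumulating rank-1
// updates through qouter; the column partials are then reduced across the
// simdgroup with simd_sum. The `remaining` branch handles an in_vec_size that
// is not a multiple of 32 by masking the loads of out-of-range lanes.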
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;

  // Instantiate the appropriate BlockMMA and Loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * K;
  wl += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
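// qmm_t_impl is the tiled GEMM against a transposed quantized matrix: a steel
// BlockLoader streams BM x BK tiles of x into Xs while QuantizedBlockLoader
// dequantizes BN x BK tiles of w into Ws, and BlockMMA accumulates the
// product. The four copies of the k-loop differ only in whether load_safe or
// load_unsafe is used per operand, depending on whether the row and column
// tiles are full. qmm_n_impl below is the non-transposed variant.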
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  // Instantiate the appropriate BlockMMA and Loaders
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  auto wl = (const device uint8_t*)w;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  wl += y_col * bytes_per_pack / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int64_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices through the gather indices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size * M,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, y, out_vec_size * M,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);

  // When in_vec_size is not divisible by split_k, the final partition along
  // the reduction dimension is smaller.
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  }
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size * M,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides, s_strides, b_strides, tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  // We accumulate up to 3 bytes worth of packed values, so use a uint32_t
  uint32_t output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output += sval << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}
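// Quantization recipe used above, per group of group_size weights:
//   scale = max((w_max - w_min) / (2^bits - 1), eps), with the sign and grid
//   chosen so that the range edge farthest from zero lands exactly on a grid
//   point (scale = edge / round(edge / scale), bias = edge), and
//   q = min(round((w - bias) / scale), n_bins). Dequantization is then
//   w_hat = scale * q + bias.
// Worked example (bits = 4, a group spanning [-1.0, 0.5]): n_bins = 15,
// scale = 1.5 / 15 = 0.1, side = true since |-1.0| > |0.5|, edge = -1.0,
// q0 = round(-1.0 / 0.1) = -10, so scale = -1.0 / -10 = 0.1 and bias = -1.0;
// w = -1.0 encodes to q = 0 and w = 0.5 encodes to q = 15.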
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];

  out += oindex;

  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;

  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < packs_per_int; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}
Definition quantized.h:468