#include <metal_simdgroup>
#include <metal_stdlib>

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }
  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
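// Explanatory aside (not in the original header): the divisions above fold
// each packed value's bit position into x ahead of time. qdot() below
// multiplies by masked-but-unshifted weight bits, e.g. for bits == 2 the
// second value is (w[i] & 0x0c) == q1 << 2 == 4 * q1, so
// (x / 4.0f) * (w[i] & 0x0c) == x * q1 with no shift in the inner loop.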
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U sum = 0;
  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }
  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  for (int i = N; i < values_per_thread; i++) {
    x_thread[i] = 0;
  }

  return sum;
}
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;
  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }
  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];
      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
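// Explanatory aside: qdot returns scale * accum + sum * bias because, with
// w_hat[j] = scale * q[j] + bias, dot(x, w_hat) = scale * dot(x, q) +
// bias * sum(x). accum holds dot(x, q) (pre-scaling trick included) and sum
// holds sum(x), so one fma finishes the dequantized dot product.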
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");

  U accum = 0;
  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }
  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];
      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");
  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  }
  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
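// Explanatory aside: qouter is the rank-1 counterpart of qdot. It accumulates
// x * (scale * q[j] + bias) into result[j] for one packed row of w; qvm_impl
// below calls it once per input element to build up x^T * W column blocks.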
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
      "Template undefined for bits not in {2, 3, 4, 6, 8}");
  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }
  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
    }
  }
  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }
  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      w_local += 4 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
    }
  }
  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
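// Packing layout recap (explanatory aside): for bits == 3, eight values span
// three bytes and values 2 and 5 straddle byte boundaries, hence the
// shift-and-or reassembly of their high bits. For bits == 6, four values span
// three bytes with values 1 and 2 straddling boundaries. Power-of-two widths
// (2, 4, 8) never cross a byte boundary.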
446 "The group size should be larger than the columns");
448 group_size % BCOLS == 0,
449 "The group size should be divisible by the columns");
451 bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
452 "Template undefined for bits not in {2, 3, 4, 6, 8}");
471 const device uint8_t*
src;
476 const device uint8_t* src_,
477 const device T* scales_,
478 const device T* biases_,
481 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
482 ushort simd_lane_id [[thread_index_in_simdgroup]])
489 thread_idx(simd_group_id * 32 + simd_lane_id),
505 for (
int i = 0; i <
n_reads; i++) {
516 if (reduction_dim == 1 &&
bi >= src_tile_dim.y) {
523 if (reduction_dim == 0 &&
bi >= src_tile_dim.x) {
532 for (
int i = 0; i <
n_reads; i++) {
543 if (reduction_dim == 1) {
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + quad_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
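// Explanatory aside: each quadgroup of QUAD_SIZE == 4 threads covers
// D / QUAD_SIZE input values, and a simdgroup hosts SIMD_SIZE / QUAD_SIZE == 8
// quadgroups, each owning results_per_quadgroup == 8 output rows strided by
// quads_per_simd; quad_sum() then reduces the partial dot products within a
// quad.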
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
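// Explanatory aside: this "fast" path does no bounds checking; it assumes
// in_vec_size is a multiple of block_size and that every output row this
// threadgroup touches exists, which the caller must guarantee before choosing
// qmv_fast over qmv.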
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even 1 tile in the matrix
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    ws +=
        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; out_row + row < out_vec_size; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values
  else {
    ws += used_out_row * in_vec_size_w +
        simd_lid * packs_per_thread * bytes_per_pack;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.y * in_vec_size + simd_lid * values_per_thread;
    y += tid.y * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    U sum = load_vector_safe<T, U, values_per_thread, bits>(
        x, x_thread, remaining);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot_safe<U, values_per_thread, bits>(
          wl, x_thread, s, b, sum, remaining);
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
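// Explanatory aside: qmv_impl handles the ragged edges that qmv_fast_impl
// ignores. When the whole output is smaller than one tile, every read is
// guarded by out_row checks; otherwise the last tile is shifted back to
// used_out_row, so some rows are recomputed but no out-of-bounds access
// occurs. The final partial block goes through load_vector_safe/qdot_safe
// with `remaining`.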
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int tn = 32 / pack_factor;
  constexpr int block_size = SIMD_SIZE;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;
  typedef struct {
    uint8_t wi[tn * bytes_per_pack];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid);
  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.y * in_vec_size + simd_lid;
  y += tid.y * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of block_size
  int remaining = in_vec_size % block_size;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
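// Explanatory aside: qvm_impl computes y = x^T * W with W packed row-major
// along the output dimension. Each simdgroup owns a block of
// tn * pack_factor output columns; each lane applies one input row per step
// as a rank-1 update via qouter, and simd_sum() combines the SIMD_SIZE
// partial results. The `remaining` branch peels the final partial block
// instead of reading past in_vec_size.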
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * K;
  wl += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
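// Explanatory aside: qmm_t_impl picks one of four inner loops so bounds
// checks are only paid where needed: safe vs. unsafe X loads depending on
// num_els < BM, crossed with safe vs. unsafe W loads depending on
// !aligned_N && num_outs < BN. The compile-time aligned_N flag removes the
// column guard entirely for well-shaped problems.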
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
  using loader_x_t = mlx::steel::
      BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BK,
      BN,
      BN_padded,
      0,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  auto wl = (const device uint8_t*)w;

  // Set the block
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
  wl += y_col * bytes_per_pack / pack_factor;
  scales += y_col / group_size;
  biases += y_col / group_size;
  y += y_row * N + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, num_els));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if ((K % BK) != 0) {
      const int k_blocks = K / BK;
      for (int k = 0; k < k_blocks; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
      const short num_k = K - k_blocks * BK;
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_x.load_safe(short2(num_k, BM));
      loader_w.load_safe(short2(BN, num_k));
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(Xs, Ws);
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM) {
    mma_op.store_result_safe(y, N, short2(BN, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    device T*& y,
    int output_stride,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx = tid.z;
  uint32_t w_idx = tid.z;
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
    const device T*& x,
    const device uint32_t*& w,
    const device T*& scales,
    const device T*& biases,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T*& y,
    int output_stride,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant size_t* lhs_strides,
    const constant size_t* rhs_strides,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant size_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant size_t* w_strides,
    const constant size_t* s_strides,
    const constant size_t* b_strides,
    uint3 tid [[threadgroup_position_in_grid]]) {
  // Set the input/output matrices
  uint32_t x_idx;
  uint32_t w_idx;
  if (batch_ndims == 1) {
    x_idx = lhs_indices[tid.z * lhs_strides[0]];
    w_idx = rhs_indices[tid.z * rhs_strides[0]];
  } else {
    ulong2 idx = elem_to_loc_broadcast(
        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
    x_idx = lhs_indices[idx.x];
    w_idx = rhs_indices[idx.y];
  }
  if (x_batch_ndims == 1) {
    x += x_idx * x_strides[0];
  } else {
    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
  }
  if (w_batch_ndims == 1) {
    w += w_idx * w_strides[0];
    scales += w_idx * s_strides[0];
    biases += w_idx * b_strides[0];
  } else {
    ulong3 idx = elem_to_loc_broadcast(
        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
    w += idx.x;
    scales += idx.y;
    biases += idx.z;
  }
  y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, quad_gid,
      quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, out_vec_size,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, y, out_vec_size,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);

  // When (in_vec_size % split_k != 0) the final block needs to be smaller
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x, w, scales, biases, y, M * N,
        x_batch_ndims, x_shape, x_strides,
        w_batch_ndims, w_shape, w_strides,
        s_strides, b_strides, tid);
  }
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmv_fast_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmv_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant size_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant size_t* w_strides [[buffer(12)]],
    const constant size_t* s_strides [[buffer(13)]],
    const constant size_t* b_strides [[buffer(14)]],
    const constant int& batch_ndims [[buffer(15)]],
    const constant int* batch_shape [[buffer(16)]],
    const device uint32_t* lhs_indices [[buffer(17)]],
    const device uint32_t* rhs_indices [[buffer(18)]],
    const constant size_t* lhs_strides [[buffer(19)]],
    const constant size_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, out_vec_size,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qvm_impl<T, group_size, bits>(
      w, scales, biases, x, y, in_vec_size, out_vec_size, tid, simd_gid,
      simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void bs_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant size_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant size_t* w_strides [[buffer(13)]],
    const constant size_t* s_strides [[buffer(14)]],
    const constant size_t* b_strides [[buffer(15)]],
    const constant int& batch_ndims [[buffer(16)]],
    const constant int* batch_shape [[buffer(17)]],
    const device uint32_t* lhs_indices [[buffer(18)]],
    const device uint32_t* rhs_indices [[buffer(19)]],
    const constant size_t* lhs_strides [[buffer(20)]],
    const constant size_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x, w, scales, biases, lhs_indices, rhs_indices, y, M * N,
      batch_ndims, batch_shape, lhs_strides, rhs_strides,
      x_batch_ndims, x_shape, x_strides,
      w_batch_ndims, w_shape, w_strides,
      s_strides, b_strides, tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr T eps = T(1e-7);
  constexpr int simd_size = 32;
  constexpr T n_bins = (1 << bits) - 1;
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  T w_thread[values_per_reduce];
  T w_min = Limits<T>::max;
  T w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    T val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  T scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  T edge = side ? w_min : w_max;
  T q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  T bias = at_zero ? T(0) : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = scale;
    biases[gindex] = bias;
  }

  uint32_t output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output += val << (bits * (i % packs_per_int));
    }

    if (packs_per_int < values_per_reduce &&
        i % packs_per_int == packs_per_int - 1) {
      out[out_index + i / packs_per_int] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output += sval << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}
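// Worked example (explanatory aside, values approximate): bits == 4 gives
// n_bins == 15. For a group with w_min == -1.5 and w_max == 2.3:
//   scale = (2.3 - (-1.5)) / 15 = 0.2533; side = (1.5 > 2.3) = false,
//   so scale = -0.2533 and edge = w_max = 2.3;
//   q0 = round(2.3 / -0.2533) = -9, so scale = 2.3 / -9 = -0.2556, bias = 2.3.
// Check: w == 2.3 quantizes to q == 0 and dequantizes to bias == 2.3 exactly;
// w == -1.5 gives q == round((-1.5 - 2.3) / -0.2556) == 15, which dequantizes
// to -0.2556 * 15 + 2.3 = -1.53, within one quantization step of the input.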
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * packs_per_int;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];

  out += oindex;

  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < packs_per_int; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}