quantized.h
1// Copyright © 2023-2024 Apple Inc.
2
3#include <metal_simdgroup>
4#include <metal_stdlib>
5
6using namespace metal;
7
8#define MLX_MTL_CONST static constant constexpr const
9
10MLX_MTL_CONST int SIMD_SIZE = 32;
11
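// load_vector: copy values_per_thread inputs from x into registers and return
// their sum. For 2- and 4-bit weights the copies are pre-divided by 4/16/64
// (resp. 16/256/4096) so that qdot can use the packed, un-shifted quantized
// values directly.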
12template <typename T, typename U, int values_per_thread, int bits>
13inline U load_vector(const device T* x, thread U* x_thread) {
14 static_assert(
15 bits == 2 || bits == 4 || bits == 8,
16 "Template undefined for bits not in {2, 4, 8}");
17
18 U sum = 0;
19
20 if (bits == 2) {
21 for (int i = 0; i < values_per_thread; i += 4) {
22 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
23 x_thread[i] = x[i];
24 x_thread[i + 1] = x[i + 1] / 4.0f;
25 x_thread[i + 2] = x[i + 2] / 16.0f;
26 x_thread[i + 3] = x[i + 3] / 64.0f;
27 }
28 }
29
30 else if (bits == 4) {
31 for (int i = 0; i < values_per_thread; i += 4) {
32 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
33 x_thread[i] = x[i];
34 x_thread[i + 1] = x[i + 1] / 16.0f;
35 x_thread[i + 2] = x[i + 2] / 256.0f;
36 x_thread[i + 3] = x[i + 3] / 4096.0f;
37 }
38 }
39
40 else if (bits == 8) {
41 for (int i = 0; i < values_per_thread; i++) {
42 sum += x[i];
43 x_thread[i] = x[i];
44 }
45 }
46
47 return sum;
48}
49
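// load_vector_safe: same as load_vector, but reads only the first N elements
// and zero-fills the rest of x_thread so a partial block can be processed.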
50template <typename T, typename U, int values_per_thread, int bits>
51inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
52 static_assert(
53 bits == 2 || bits == 4 || bits == 8,
54 "Template undefined for bits not in {2, 4, 8}");
55
56 U sum = 0;
57
58 if (bits == 2) {
59 for (int i = 0; i < N; i += 4) {
60 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
61 x_thread[i] = x[i];
62 x_thread[i + 1] = x[i + 1] / 4.0f;
63 x_thread[i + 2] = x[i + 2] / 16.0f;
64 x_thread[i + 3] = x[i + 3] / 64.0f;
65 }
66 for (int i = N; i < values_per_thread; i++) {
67 x_thread[i] = 0;
68 }
69 }
70
71 else if (bits == 4) {
72 for (int i = 0; i < N; i += 4) {
73 sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
74 x_thread[i] = x[i];
75 x_thread[i + 1] = x[i + 1] / 16.0f;
76 x_thread[i + 2] = x[i + 2] / 256.0f;
77 x_thread[i + 3] = x[i + 3] / 4096.0f;
78 }
79 for (int i = N; i < values_per_thread; i++) {
80 x_thread[i] = 0;
81 }
82 }
83
84 else if (bits == 8) {
85 for (int i = 0; i < N; i++) {
86 sum += x[i];
87 x_thread[i] = x[i];
88 }
89 for (int i = N; i < values_per_thread; i++) {
90 x_thread[i] = 0;
91 }
92 }
93
94 return sum;
95}
96
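// qdot: dot product of one thread's inputs with packed quantized weights.
// The bit-fields are masked but not shifted; the shift is already folded into
// x_thread by load_vector. Since each dequantized weight is scale * q + bias,
// the result is scale * dot(x, q) + bias * sum(x), with sum(x) passed in.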
97template <typename U, int values_per_thread, int bits>
98inline U qdot(
99 const device uint8_t* w,
100 const thread U* x_thread,
101 U scale,
102 U bias,
103 U sum) {
104 static_assert(
105 bits == 2 || bits == 4 || bits == 8,
106 "Template undefined for bits not in {2, 4, 8}");
107
108 U accum = 0;
109
110 if (bits == 2) {
111 for (int i = 0; i < (values_per_thread / 4); i++) {
112 accum +=
113 (x_thread[4 * i] * (w[i] & 0x03) +
114 x_thread[4 * i + 1] * (w[i] & 0x0c) +
115 x_thread[4 * i + 2] * (w[i] & 0x30) +
116 x_thread[4 * i + 3] * (w[i] & 0xc0));
117 }
118 }
119
120 else if (bits == 4) {
121 const device uint16_t* ws = (const device uint16_t*)w;
122 for (int i = 0; i < (values_per_thread / 4); i++) {
123 accum +=
124 (x_thread[4 * i] * (ws[i] & 0x000f) +
125 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
126 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
127 x_thread[4 * i + 3] * (ws[i] & 0xf000));
128 }
129 }
130
131 else if (bits == 8) {
132 for (int i = 0; i < values_per_thread; i++) {
133 accum += x_thread[i] * w[i];
134 }
135 }
136
137 return scale * accum + sum * bias;
138}
139
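// qdot_safe: qdot restricted to the first N values of a partial block.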
140template <typename U, int values_per_thread, int bits>
141inline U qdot_safe(
142 const device uint8_t* w,
143 const thread U* x_thread,
144 U scale,
145 U bias,
146 U sum,
147 int N) {
148 static_assert(
149 bits == 2 || bits == 4 || bits == 8,
150 "Template undefined for bits not in {2, 4, 8}");
151
152 U accum = 0;
153
154 if (bits == 2) {
155 for (int i = 0; i < (N / 4); i++) {
156 accum +=
157 (x_thread[4 * i] * (w[i] & 0x03) +
158 x_thread[4 * i + 1] * (w[i] & 0x0c) +
159 x_thread[4 * i + 2] * (w[i] & 0x30) +
160 x_thread[4 * i + 3] * (w[i] & 0xc0));
161 }
162 }
163
164 else if (bits == 4) {
165 const device uint16_t* ws = (const device uint16_t*)w;
166 for (int i = 0; i < (N / 4); i++) {
167 accum +=
168 (x_thread[4 * i] * (ws[i] & 0x000f) +
169 x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
170 x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
171 x_thread[4 * i + 3] * (ws[i] & 0xf000));
172 }
173 }
174
175 else if (bits == 8) {
176 for (int i = 0; i < N; i++) {
177 accum += x_thread[i] * w[i];
178 }
179 }
180
181 return scale * accum + sum * bias;
182}
183
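// qouter: rank-1 update result += x * dequantize(w) over this thread's strip
// of output columns, used by the qvm kernels. The per-lane scales s[] fold
// the bit-shift of each packed value into the dequantization.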
184template <typename U, int values_per_thread, int bits>
185inline void
186qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
187 static_assert(
188 bits == 2 || bits == 4 || bits == 8,
189 "Template undefined for bits not in {2, 4, 8}");
190
191 if (bits == 2) {
192 U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
193 for (int i = 0; i < (values_per_thread / 4); i++) {
194 result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
195 result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
196 result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
197 result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
198 }
199 }
200
201 else if (bits == 4) {
202 U s[2] = {scale, scale / 16.0f};
203 for (int i = 0; i < (values_per_thread / 2); i++) {
204 result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
205 result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
206 }
207 }
208
209 else if (bits == 8) {
210 for (int i = 0; i < values_per_thread; i++) {
211 result[i] += x * (scale * w[i] + bias);
212 }
213 }
214}
215
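// dequantize: expand N packed values into threadgroup memory as
// scale * q + bias, again folding the per-lane shift into the scale.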
216template <typename U, int N, int bits>
217inline void
218dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
219 static_assert(
220 bits == 2 || bits == 4 || bits == 8,
221 "Template undefined for bits not in {2, 4, 8}");
222
223 if (bits == 2) {
224 U s[4] = {
225 scale,
226 scale / static_cast<U>(4.0f),
227 scale / static_cast<U>(16.0f),
228 scale / static_cast<U>(64.0f)};
229 for (int i = 0; i < (N / 4); i++) {
230 w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
231 w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
232 w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
233 w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
234 }
235 }
236
237 else if (bits == 4) {
238 U s[2] = {scale, scale / static_cast<U>(16.0f)};
239 for (int i = 0; i < (N / 2); i++) {
240 w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
241 w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
242 }
243 }
244
245 else if (bits == 8) {
246 for (int i = 0; i < N; i++) {
247 w_local[i] = scale * w[i] + bias;
248 }
249 }
250}
251
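// QuantizedBlockLoader: loads a BROWS x BCOLS tile of quantized weights,
// dequantizes it with the matching scales/biases, and writes the result to
// threadgroup memory, playing the same role as mlx::steel::BlockLoader in
// the qmm kernels below.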
252template <
253 typename T,
254 short BROWS,
255 short BCOLS,
256 short dst_ld,
257 short reduction_dim,
258 short tgp_size,
259 short group_size,
260 short bits>
261struct QuantizedBlockLoader {
262 static_assert(
263 BCOLS <= group_size,
264 "The group size should be larger than the columns");
265 static_assert(
266 group_size % BCOLS == 0,
267 "The group size should be divisible by the columns");
268 static_assert(
269 bits == 2 || bits == 4 || bits == 8,
270 "Template undefined for bits not in {2, 4, 8}");
271
272 MLX_MTL_CONST short pack_factor = 32 / bits;
273 MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
274 MLX_MTL_CONST short n_reads =
275 (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
276 MLX_MTL_CONST short group_steps = group_size / BCOLS;
277
278 const int src_ld;
279 const int tile_stride;
280 short group_step_cnt;
281 const int group_stride;
282
283 const short thread_idx;
284 const short bi;
285 const short bj;
286
287 threadgroup T* dst;
288 const device uint32_t* src;
289 const device T* scales;
290 const device T* biases;
291
292 QuantizedBlockLoader(
293 const device uint32_t* src_,
294 const device T* scales_,
295 const device T* biases_,
296 const int src_ld_,
297 threadgroup T* dst_,
298 ushort simd_group_id [[simdgroup_index_in_threadgroup]],
299 ushort simd_lane_id [[thread_index_in_simdgroup]])
300 : src_ld(src_ld_),
301 tile_stride(
302 reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
303 group_step_cnt(0),
304 group_stride(BROWS * src_ld / group_size),
305 thread_idx(simd_group_id * 32 + simd_lane_id),
306 bi(n_reads * thread_idx / BCOLS_PACKED),
307 bj((n_reads * thread_idx) % BCOLS_PACKED),
308 dst(dst_ + bi * dst_ld + bj * pack_factor),
309 src(src_ + bi * src_ld / pack_factor + bj),
310 scales(scales_ + bi * src_ld / group_size),
311 biases(biases_ + bi * src_ld / group_size) {}
312
313 void load_unsafe() const {
314 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
315 return;
316 }
317
318 T scale = *scales;
319 T bias = *biases;
320 for (int i = 0; i < n_reads; i++) {
321 dequantize<T, pack_factor, bits>(
322 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
323 }
324 }
325
326 void load_safe(short2 src_tile_dim) const {
327 if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
328 return;
329 }
330
331 if (reduction_dim == 1 && bi >= src_tile_dim.y) {
332 for (int i = 0; i < n_reads * pack_factor; i++) {
333 dst[i] = T(0);
334 }
335 return;
336 }
337
338 if (reduction_dim == 0 && bi >= src_tile_dim.x) {
339 for (int i = 0; i < n_reads * pack_factor; i++) {
340 dst[i] = T(0);
341 }
342 return;
343 }
344
345 T scale = *scales;
346 T bias = *biases;
347 for (int i = 0; i < n_reads; i++) {
348 dequantize<T, pack_factor, bits>(
349 (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
350 }
351 }
352
353 void next() {
354 src += tile_stride;
355 if (reduction_dim == 1) {
356 if (group_steps > 1) {
357 group_step_cnt++;
358 if (group_step_cnt == group_steps) {
359 group_step_cnt = 0;
360 scales++;
361 biases++;
362 }
363 } else {
364 scales++;
365 biases++;
366 }
367 } else {
368 scales += group_stride;
369 biases += group_stride;
370 }
371 }
372};
373
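// qmv_fast_impl: quantized matrix-vector product y = W x for shapes where
// every simdgroup owns full tiles, so no bounds checks are needed. Each of
// the two simdgroups per threadgroup produces four output rows, marching over
// the input vector in blocks of values_per_thread * SIMD_SIZE.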
374template <typename T, int group_size, int bits>
375METAL_FUNC void qmv_fast_impl(
376 const device uint32_t* w,
377 const device T* scales,
378 const device T* biases,
379 const device T* x,
380 device T* y,
381 const constant int& in_vec_size,
382 const constant int& out_vec_size,
383 uint3 tid [[threadgroup_position_in_grid]],
384 uint simd_gid [[simdgroup_index_in_threadgroup]],
385 uint simd_lid [[thread_index_in_simdgroup]]) {
386 constexpr int packs_per_thread = bits > 2 ? 2 : 1;
387 constexpr int num_simdgroups = 2;
388 constexpr int results_per_simdgroup = 4;
389 constexpr int pack_factor = 32 / bits;
390 constexpr int values_per_thread = pack_factor * packs_per_thread;
391 constexpr int block_size = values_per_thread * SIMD_SIZE;
392 constexpr int scale_step_per_thread = group_size / values_per_thread;
393
394 typedef float U;
395
396 thread U x_thread[values_per_thread];
397 thread U result[results_per_simdgroup] = {0};
398
399 // Adjust positions
400 const int in_vec_size_w = in_vec_size / pack_factor;
401 const int in_vec_size_g = in_vec_size / group_size;
402 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
403 simd_gid * results_per_simdgroup;
404 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
405 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
406 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
407 x += tid.y * in_vec_size + simd_lid * values_per_thread;
408 y += tid.y * out_vec_size + out_row;
409
410 for (int k = 0; k < in_vec_size; k += block_size) {
411 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
412
413 for (int row = 0; row < results_per_simdgroup; row++) {
414 const device uint8_t* wl =
415 (const device uint8_t*)(w + row * in_vec_size_w);
416 const device T* sl = scales + row * in_vec_size_g;
417 const device T* bl = biases + row * in_vec_size_g;
418
419 U s = sl[0];
420 U b = bl[0];
421 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
422 }
423
424 w += block_size / pack_factor;
425 scales += block_size / group_size;
426 biases += block_size / group_size;
427 x += block_size;
428 }
429
430 for (int row = 0; row < results_per_simdgroup; row++) {
431 result[row] = simd_sum(result[row]);
432 if (simd_lid == 0) {
433 y[row] = static_cast<T>(result[row]);
434 }
435 }
436}
437
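// qmv_impl: general quantized matrix-vector product. The first branch guards
// every read when out_vec_size is smaller than one tile; otherwise the last
// tile is shifted back (used_out_row) so trailing rows are recomputed instead
// of read out of bounds, and the trailing partial block of the input vector
// is loaded with load_vector_safe.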
438template <typename T, int group_size, int bits>
439METAL_FUNC void qmv_impl(
440 const device uint32_t* w,
441 const device T* scales,
442 const device T* biases,
443 const device T* x,
444 device T* y,
445 const constant int& in_vec_size,
446 const constant int& out_vec_size,
447 uint3 tid [[threadgroup_position_in_grid]],
448 uint simd_gid [[simdgroup_index_in_threadgroup]],
449 uint simd_lid [[thread_index_in_simdgroup]]) {
450 constexpr int num_simdgroups = 2;
451 constexpr int results_per_simdgroup = 4;
452 constexpr int packs_per_thread = 1;
453 constexpr int pack_factor = 32 / bits;
454 constexpr int values_per_thread = pack_factor * packs_per_thread;
455 constexpr int block_size = values_per_thread * SIMD_SIZE;
456 constexpr int scale_step_per_thread = group_size / values_per_thread;
457
458 typedef float U;
459
460 thread U x_thread[values_per_thread];
461 thread U result[results_per_simdgroup] = {0};
462
463 // Adjust positions
464 const int in_vec_size_w = in_vec_size / pack_factor;
465 const int in_vec_size_g = in_vec_size / group_size;
466 const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
467 simd_gid * results_per_simdgroup;
468 const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
469
470 if (out_row >= out_vec_size) {
471 return;
472 }
473
474 // In this case we need to properly guard all our reads because there isn't
475 // even 1 tile in the matrix
476 if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
477 w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
478 scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
479 biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
480 x += tid.y * in_vec_size + simd_lid * values_per_thread;
481 y += tid.y * out_vec_size + out_row;
482
483 int k = 0;
484 for (; k < in_vec_size - block_size; k += block_size) {
485 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
486
487 for (int row = 0; out_row + row < out_vec_size; row++) {
488 const device uint8_t* wl =
489 (const device uint8_t*)(w + row * in_vec_size_w);
490 const device T* sl = scales + row * in_vec_size_g;
491 const device T* bl = biases + row * in_vec_size_g;
492
493 U s = sl[0];
494 U b = bl[0];
495 result[row] +=
496 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
497 }
498
499 w += block_size / pack_factor;
500 scales += block_size / group_size;
501 biases += block_size / group_size;
502 x += block_size;
503 }
504 const int remaining = clamp(
505 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
506 0,
507 values_per_thread);
508 U sum =
509 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
510
511 for (int row = 0; out_row + row < out_vec_size; row++) {
512 const device uint8_t* wl =
513 (const device uint8_t*)(w + row * in_vec_size_w);
514 const device T* sl = scales + row * in_vec_size_g;
515 const device T* bl = biases + row * in_vec_size_g;
516
517 U s = sl[0];
518 U b = bl[0];
519 result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
520 }
521
522 for (int row = 0; out_row + row < out_vec_size; row++) {
523 result[row] = simd_sum(result[row]);
524 if (simd_lid == 0) {
525 y[row] = static_cast<T>(result[row]);
526 }
527 }
528 }
529
530 // In this case the last tile is moved back to redo some output values
531 else {
532 w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
533 scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
534 biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
535 x += tid.y * in_vec_size + simd_lid * values_per_thread;
536 y += tid.y * out_vec_size + used_out_row;
537
538 int k = 0;
539 for (; k < in_vec_size - block_size; k += block_size) {
540 U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
541
542 for (int row = 0; row < results_per_simdgroup; row++) {
543 const device uint8_t* wl =
544 (const device uint8_t*)(w + row * in_vec_size_w);
545 const device T* sl = scales + row * in_vec_size_g;
546 const device T* bl = biases + row * in_vec_size_g;
547
548 U s = sl[0];
549 U b = bl[0];
550 result[row] +=
551 qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
552 }
553
554 w += block_size / pack_factor;
555 scales += block_size / group_size;
556 biases += block_size / group_size;
557 x += block_size;
558 }
559 const int remaining = clamp(
560 static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
561 0,
562 values_per_thread);
563 U sum =
564 load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
565
566 for (int row = 0; row < results_per_simdgroup; row++) {
567 const device uint8_t* wl =
568 (const device uint8_t*)(w + row * in_vec_size_w);
569 const device T* sl = scales + row * in_vec_size_g;
570 const device T* bl = biases + row * in_vec_size_g;
571
572 U s = sl[0];
573 U b = bl[0];
574 result[row] += qdot_safe<U, values_per_thread, bits>(
575 wl, x_thread, s, b, sum, remaining);
576 }
577
578 for (int row = 0; row < results_per_simdgroup; row++) {
579 result[row] = simd_sum(result[row]);
580 if (simd_lid == 0) {
581 y[row] = static_cast<T>(result[row]);
582 }
583 }
584 }
585}
586
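// qvm_impl: quantized vector-matrix product y = x W. Each thread owns one
// input element per step and tn packed words of the corresponding weight row,
// accumulating with qouter and reducing across the simdgroup at the end.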
587template <typename T, const int group_size, const int bits>
588METAL_FUNC void qvm_impl(
589 const device T* x,
590 const device uint32_t* w,
591 const device T* scales,
592 const device T* biases,
593 device T* y,
594 const constant int& in_vec_size,
595 const constant int& out_vec_size,
596 uint3 tid [[threadgroup_position_in_grid]],
597 uint simd_gid [[simdgroup_index_in_threadgroup]],
598 uint simd_lid [[thread_index_in_simdgroup]]) {
599 constexpr int num_simdgroups = 2;
600 constexpr int pack_factor = 32 / bits;
601 constexpr int tn = 32 / pack_factor;
602 constexpr int blocksize = SIMD_SIZE;
603
604 typedef float U;
605 typedef struct {
606 uint32_t wi[tn];
607 } vec_w;
608
609 thread vec_w w_local;
610 thread U result[tn * pack_factor] = {0};
611 thread U scale = 1;
612 thread U bias = 0;
613 thread U x_local = 0;
614
615 // Adjust positions
616 const int out_vec_size_w = out_vec_size / pack_factor;
617 const int out_vec_size_g = out_vec_size / group_size;
618 int out_col =
619 tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
620 w += out_col / pack_factor + simd_lid * out_vec_size_w;
621 scales += out_col / group_size + simd_lid * out_vec_size_g;
622 biases += out_col / group_size + simd_lid * out_vec_size_g;
623 x += tid.y * in_vec_size + simd_lid;
624 y += tid.y * out_vec_size + out_col;
625
626 if (out_col >= out_vec_size) {
627 return;
628 }
629
630 // Loop over in_vec in blocks of blocksize
631 int remaining = in_vec_size % blocksize;
632 if (remaining == 0) {
633 for (int i = 0; i < in_vec_size; i += blocksize) {
634 x_local = *x;
635 scale = *scales;
636 bias = *biases;
637 w_local = *((device vec_w*)w);
638
639 qouter<U, tn * pack_factor, bits>(
640 (thread uint8_t*)&w_local, x_local, scale, bias, result);
641
642 x += blocksize;
643 scales += blocksize * out_vec_size_g;
644 biases += blocksize * out_vec_size_g;
645 w += blocksize * out_vec_size_w;
646 }
647 } else {
648 for (int i = blocksize; i < in_vec_size; i += blocksize) {
649 x_local = *x;
650 scale = *scales;
651 bias = *biases;
652 w_local = *((device vec_w*)w);
653
654 qouter<U, tn * pack_factor, bits>(
655 (thread uint8_t*)&w_local, x_local, scale, bias, result);
656
657 x += blocksize;
658 scales += blocksize * out_vec_size_g;
659 biases += blocksize * out_vec_size_g;
660 w += blocksize * out_vec_size_w;
661 }
662 if (static_cast<int>(simd_lid) < remaining) {
663 x_local = *x;
664 scale = *scales;
665 bias = *biases;
666 w_local = *((device vec_w*)w);
667 } else {
668 x_local = 0;
669 scale = 0;
670 bias = 0;
671 }
672 qouter<U, tn * pack_factor, bits>(
673 (thread uint8_t*)&w_local, x_local, scale, bias, result);
674 }
675
676// Accumulate in the simdgroup
677#pragma clang loop unroll(full)
678 for (int k = 0; k < tn * pack_factor; k++) {
679 result[k] = simd_sum(result[k]);
680 }
681
682 // Store the result
683 if (simd_lid == 0) {
684#pragma clang loop unroll(full)
685 for (int k = 0; k < tn * pack_factor; k++) {
686 y[k] = static_cast<T>(result[k]);
687 }
688 }
689}
690
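// qmm_t_impl: quantized GEMM with the weight matrix transposed (y = x * w^T),
// tiled BM x BN x BK through threadgroup memory. x is staged with the steel
// BlockLoader and w with QuantizedBlockLoader; the four branches pick
// safe/unsafe loads depending on whether the row and column tiles are full.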
691template <
692 typename T,
693 const int group_size,
694 const int bits,
695 const bool aligned_N,
696 const int BM = 32,
697 const int BK = 32,
698 const int BN = 32>
699METAL_FUNC void qmm_t_impl(
700 const device T* x,
701 const device uint32_t* w,
702 const device T* scales,
703 const device T* biases,
704 device T* y,
705 threadgroup T* Xs,
706 threadgroup T* Ws,
707 const constant int& M,
708 const constant int& N,
709 const constant int& K,
710 uint3 tid [[threadgroup_position_in_grid]],
711 uint lid [[thread_index_in_threadgroup]],
712 uint simd_gid [[simdgroup_index_in_threadgroup]],
713 uint simd_lid [[thread_index_in_simdgroup]]) {
714 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
715 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
716
717 (void)lid;
718
719 constexpr int WM = 2;
720 constexpr int WN = 2;
721 constexpr int pack_factor = 32 / bits;
722 constexpr int BK_padded = (BK + 16 / sizeof(T));
723
724 // Instantiate the appropriate BlockMMA and Loader
725 using mma_t = mlx::steel::
726 BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
727 using loader_x_t =
728 mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
729 using loader_w_t = QuantizedBlockLoader<
730 T,
731 BN,
732 BK,
733 BK_padded,
734 1,
735 WM * WN * SIMD_SIZE,
736 group_size,
737 bits>;
738
739 // Set the block
740 const int K_w = K / pack_factor;
741 const int K_g = K / group_size;
742 const int y_row = tid.y * BM;
743 const int y_col = tid.x * BN;
744
745 x += y_row * K;
746 w += y_col * K_w;
747 scales += y_col * K_g;
748 biases += y_col * K_g;
749 y += y_row * N + y_col;
750
751 // Make the x loader and mma operation
752 const short num_els = min(BM, M - y_row);
753 const short num_outs = min(BN, N - y_col);
754 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
755 loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid);
756 mma_t mma_op(simd_gid, simd_lid);
757
758 if (num_els < BM) {
759 if (!aligned_N && num_outs < BN) {
760 for (int k = 0; k < K; k += BK) {
761 threadgroup_barrier(mem_flags::mem_threadgroup);
762 loader_x.load_safe(short2(BK, num_els));
763 loader_w.load_safe(short2(BK, num_outs));
764 threadgroup_barrier(mem_flags::mem_threadgroup);
765 mma_op.mma(Xs, Ws);
766 loader_x.next();
767 loader_w.next();
768 }
769 } else {
770 for (int k = 0; k < K; k += BK) {
771 threadgroup_barrier(mem_flags::mem_threadgroup);
772 loader_x.load_safe(short2(BK, num_els));
773 loader_w.load_unsafe();
774 threadgroup_barrier(mem_flags::mem_threadgroup);
775 mma_op.mma(Xs, Ws);
776 loader_x.next();
777 loader_w.next();
778 }
779 }
780 } else {
781 if (!aligned_N && num_outs < BN) {
782 for (int k = 0; k < K; k += BK) {
783 threadgroup_barrier(mem_flags::mem_threadgroup);
784 loader_x.load_unsafe();
785 loader_w.load_safe(short2(BK, num_outs));
786 threadgroup_barrier(mem_flags::mem_threadgroup);
787 mma_op.mma(Xs, Ws);
788 loader_x.next();
789 loader_w.next();
790 }
791 } else {
792 for (int k = 0; k < K; k += BK) {
793 threadgroup_barrier(mem_flags::mem_threadgroup);
794 loader_x.load_unsafe();
795 loader_w.load_unsafe();
796 threadgroup_barrier(mem_flags::mem_threadgroup);
797 mma_op.mma(Xs, Ws);
798 loader_x.next();
799 loader_w.next();
800 }
801 }
802 }
803
804 // Store results to device memory
805 threadgroup_barrier(mem_flags::mem_threadgroup);
806 if (num_els < BM || num_outs < BN) {
807 mma_op.store_result_safe(y, N, short2(num_outs, num_els));
808 } else {
809 mma_op.store_result(y, N);
810 }
811}
812
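// qmm_n_impl: quantized GEMM with w in row-major layout (y = x * w). K is not
// assumed to be a multiple of BK, so a trailing partial K-block is handled
// with safe loads.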
813template <
814 typename T,
815 const int group_size,
816 const int bits,
817 const int BM = 32,
818 const int BK = 32,
819 const int BN = 32>
820METAL_FUNC void qmm_n_impl(
821 const device T* x,
822 const device uint32_t* w,
823 const device T* scales,
824 const device T* biases,
825 device T* y,
826 threadgroup T* Xs,
827 threadgroup T* Ws,
828 const constant int& M,
829 const constant int& N,
830 const constant int& K,
831 uint3 tid [[threadgroup_position_in_grid]],
832 uint lid [[thread_index_in_threadgroup]],
833 uint simd_gid [[simdgroup_index_in_threadgroup]],
834 uint simd_lid [[thread_index_in_simdgroup]]) {
835 static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
836 static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
837
838 (void)lid;
839
840 constexpr int WM = 2;
841 constexpr int WN = 2;
842 constexpr int pack_factor = 32 / bits;
843 constexpr int BK_padded = (BK + 16 / sizeof(T));
844 constexpr int BN_padded = (BN + 16 / sizeof(T));
845
846 // Instantiate the appropriate BlockMMA and Loader
847 using mma_t = mlx::steel::
848 BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
849 using loader_x_t = mlx::steel::
850 BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
851 using loader_w_t = QuantizedBlockLoader<
852 T,
853 BK,
854 BN,
855 BN_padded,
856 0,
857 WM * WN * SIMD_SIZE,
858 group_size,
859 bits>;
860
861 // Set the block
862 const int y_row = tid.y * BM;
863 const int y_col = tid.x * BN;
864 x += y_row * K;
865 w += y_col / pack_factor;
866 scales += y_col / group_size;
867 biases += y_col / group_size;
868 y += y_row * N + y_col;
869
870 // Make the x loader and mma operation
871 const short num_els = min(BM, M - y_row);
872 loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
873 loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid);
874 mma_t mma_op(simd_gid, simd_lid);
875
876 if (num_els < BM) {
877 if ((K % BK) != 0) {
878 const int k_blocks = K / BK;
879 for (int k = 0; k < k_blocks; k++) {
880 threadgroup_barrier(mem_flags::mem_threadgroup);
881 loader_x.load_safe(short2(BK, num_els));
882 loader_w.load_unsafe();
883 threadgroup_barrier(mem_flags::mem_threadgroup);
884 mma_op.mma(Xs, Ws);
885 loader_x.next();
886 loader_w.next();
887 }
888 const short num_k = K - k_blocks * BK;
889 threadgroup_barrier(mem_flags::mem_threadgroup);
890 loader_x.load_safe(short2(num_k, num_els));
891 loader_w.load_safe(short2(BN, num_k));
892 threadgroup_barrier(mem_flags::mem_threadgroup);
893 mma_op.mma(Xs, Ws);
894 } else {
895 for (int k = 0; k < K; k += BK) {
896 threadgroup_barrier(mem_flags::mem_threadgroup);
897 loader_x.load_safe(short2(BK, num_els));
898 loader_w.load_unsafe();
899 threadgroup_barrier(mem_flags::mem_threadgroup);
900 mma_op.mma(Xs, Ws);
901 loader_x.next();
902 loader_w.next();
903 }
904 }
905 } else {
906 if ((K % BK) != 0) {
907 const int k_blocks = K / BK;
908 for (int k = 0; k < k_blocks; k++) {
909 threadgroup_barrier(mem_flags::mem_threadgroup);
910 loader_x.load_unsafe();
911 loader_w.load_unsafe();
912 threadgroup_barrier(mem_flags::mem_threadgroup);
913 mma_op.mma(Xs, Ws);
914 loader_x.next();
915 loader_w.next();
916 }
917 const short num_k = K - k_blocks * BK;
918 threadgroup_barrier(mem_flags::mem_threadgroup);
919 loader_x.load_safe(short2(num_k, BM));
920 loader_w.load_safe(short2(BN, num_k));
921 threadgroup_barrier(mem_flags::mem_threadgroup);
922 mma_op.mma(Xs, Ws);
923 } else {
924 for (int k = 0; k < K; k += BK) {
925 threadgroup_barrier(mem_flags::mem_threadgroup);
926 loader_x.load_unsafe();
927 loader_w.load_unsafe();
928 threadgroup_barrier(mem_flags::mem_threadgroup);
929 mma_op.mma(Xs, Ws);
930 loader_x.next();
931 loader_w.next();
932 }
933 }
934 }
935
936 // Store results to device memory
937 threadgroup_barrier(mem_flags::mem_threadgroup);
938 if (num_els < BM) {
939 mma_op.store_result_safe(y, N, short2(BN, num_els));
940 } else {
941 mma_op.store_result(y, N);
942 }
943}
944
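// adjust_matrix_offsets: for the batched/gathered (bs_*) kernels, translate
// the threadgroup's z index through lhs_indices/rhs_indices and the batch
// strides, then advance x, w, scales, biases, and y to the selected matrices.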
945template <typename T>
946METAL_FUNC void adjust_matrix_offsets(
947 const device T*& x,
948 const device uint32_t*& w,
949 const device T*& scales,
950 const device T*& biases,
951 const device uint32_t* lhs_indices,
952 const device uint32_t* rhs_indices,
953 device T*& y,
954 int output_stride,
955 const constant int& batch_ndims,
956 const constant int* batch_shape,
957 const constant size_t* lhs_strides,
958 const constant size_t* rhs_strides,
959 const constant int& x_batch_ndims,
960 const constant int* x_shape,
961 const constant size_t* x_strides,
962 const constant int& w_batch_ndims,
963 const constant int* w_shape,
964 const constant size_t* w_strides,
965 const constant size_t* s_strides,
966 const constant size_t* b_strides,
967 uint3 tid [[threadgroup_position_in_grid]]) {
968 // Set the input/output matrices
969 uint32_t x_idx;
970 uint32_t w_idx;
971 if (batch_ndims == 1) {
972 x_idx = lhs_indices[tid.z * lhs_strides[0]];
973 w_idx = rhs_indices[tid.z * rhs_strides[0]];
974 } else {
975 ulong2 idx = elem_to_loc_broadcast(
976 tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
977 x_idx = lhs_indices[idx.x];
978 w_idx = rhs_indices[idx.y];
979 }
980 if (x_batch_ndims == 1) {
981 x += x_idx * x_strides[0];
982 } else {
983 x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
984 }
985 if (w_batch_ndims == 1) {
986 w += w_idx * w_strides[0];
987 scales += w_idx * s_strides[0];
988 biases += w_idx * b_strides[0];
989 } else {
990 ulong3 idx = elem_to_loc_broadcast(
991 w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
992 w += idx.x;
993 scales += idx.y;
994 biases += idx.z;
995 }
996 y += tid.z * output_stride;
997}
998
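// The [[kernel]] entry points below are thin wrappers that bind the Metal
// buffer arguments and forward to the *_impl functions; the bs_* variants
// first gather the per-batch operands with adjust_matrix_offsets.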
999template <typename T, int group_size, int bits>
1000[[kernel]] void qmv_fast(
1001 const device uint32_t* w [[buffer(0)]],
1002 const device T* scales [[buffer(1)]],
1003 const device T* biases [[buffer(2)]],
1004 const device T* x [[buffer(3)]],
1005 device T* y [[buffer(4)]],
1006 const constant int& in_vec_size [[buffer(5)]],
1007 const constant int& out_vec_size [[buffer(6)]],
1008 uint3 tid [[threadgroup_position_in_grid]],
1009 uint simd_gid [[simdgroup_index_in_threadgroup]],
1010 uint simd_lid [[thread_index_in_simdgroup]]) {
1011 qmv_fast_impl<T, group_size, bits>(
1012 w,
1013 scales,
1014 biases,
1015 x,
1016 y,
1017 in_vec_size,
1018 out_vec_size,
1019 tid,
1020 simd_gid,
1021 simd_lid);
1022}
1023
1024template <typename T, const int group_size, const int bits>
1025[[kernel]] void qmv(
1026 const device uint32_t* w [[buffer(0)]],
1027 const device T* scales [[buffer(1)]],
1028 const device T* biases [[buffer(2)]],
1029 const device T* x [[buffer(3)]],
1030 device T* y [[buffer(4)]],
1031 const constant int& in_vec_size [[buffer(5)]],
1032 const constant int& out_vec_size [[buffer(6)]],
1033 uint3 tid [[threadgroup_position_in_grid]],
1034 uint simd_gid [[simdgroup_index_in_threadgroup]],
1035 uint simd_lid [[thread_index_in_simdgroup]]) {
1036 qmv_impl<T, group_size, bits>(
1037 w,
1038 scales,
1039 biases,
1040 x,
1041 y,
1042 in_vec_size,
1043 out_vec_size,
1044 tid,
1045 simd_gid,
1046 simd_lid);
1047}
1048
1049template <typename T, const int group_size, const int bits>
1050[[kernel]] void qvm(
1051 const device T* x [[buffer(0)]],
1052 const device uint32_t* w [[buffer(1)]],
1053 const device T* scales [[buffer(2)]],
1054 const device T* biases [[buffer(3)]],
1055 device T* y [[buffer(4)]],
1056 const constant int& in_vec_size [[buffer(5)]],
1057 const constant int& out_vec_size [[buffer(6)]],
1058 uint3 tid [[threadgroup_position_in_grid]],
1059 uint simd_gid [[simdgroup_index_in_threadgroup]],
1060 uint simd_lid [[thread_index_in_simdgroup]]) {
1061 qvm_impl<T, group_size, bits>(
1062 x,
1063 w,
1064 scales,
1065 biases,
1066 y,
1067 in_vec_size,
1068 out_vec_size,
1069 tid,
1070 simd_gid,
1071 simd_lid);
1072}
1073
1074template <
1075 typename T,
1076 const int group_size,
1077 const int bits,
1078 const bool aligned_N,
1079 const int BM = 32,
1080 const int BK = 32,
1081 const int BN = 32>
1082[[kernel]] void qmm_t(
1083 const device T* x [[buffer(0)]],
1084 const device uint32_t* w [[buffer(1)]],
1085 const device T* scales [[buffer(2)]],
1086 const device T* biases [[buffer(3)]],
1087 device T* y [[buffer(4)]],
1088 const constant int& M [[buffer(5)]],
1089 const constant int& N [[buffer(6)]],
1090 const constant int& K [[buffer(7)]],
1091 uint3 tid [[threadgroup_position_in_grid]],
1092 uint lid [[thread_index_in_threadgroup]],
1093 uint simd_gid [[simdgroup_index_in_threadgroup]],
1094 uint simd_lid [[thread_index_in_simdgroup]]) {
1095 (void)lid;
1096
1097 constexpr int BK_padded = (BK + 16 / sizeof(T));
1098
1099 threadgroup T Xs[BM * BK_padded];
1100 threadgroup T Ws[BN * BK_padded];
1101
1102 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1103 x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
1104}
1105
1106template <
1107 typename T,
1108 const int group_size,
1109 const int bits,
1110 const int BM = 32,
1111 const int BK = 32,
1112 const int BN = 32>
1113[[kernel]] void qmm_n(
1114 const device T* x [[buffer(0)]],
1115 const device uint32_t* w [[buffer(1)]],
1116 const device T* scales [[buffer(2)]],
1117 const device T* biases [[buffer(3)]],
1118 device T* y [[buffer(4)]],
1119 const constant int& M [[buffer(5)]],
1120 const constant int& N [[buffer(6)]],
1121 const constant int& K [[buffer(7)]],
1122 uint3 tid [[threadgroup_position_in_grid]],
1123 uint lid [[thread_index_in_threadgroup]],
1124 uint simd_gid [[simdgroup_index_in_threadgroup]],
1125 uint simd_lid [[thread_index_in_simdgroup]]) {
1126 (void)lid;
1127
1128 constexpr int BK_padded = (BK + 16 / sizeof(T));
1129 constexpr int BN_padded = (BN + 16 / sizeof(T));
1130
1131 threadgroup T Xs[BM * BK_padded];
1132 threadgroup T Ws[BK * BN_padded];
1133
1134 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1135 x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
1136}
1137
1138template <typename T, int group_size, int bits>
1139[[kernel]] void bs_qmv_fast(
1140 const device uint32_t* w [[buffer(0)]],
1141 const device T* scales [[buffer(1)]],
1142 const device T* biases [[buffer(2)]],
1143 const device T* x [[buffer(3)]],
1144 const device uint32_t* lhs_indices [[buffer(4)]],
1145 const device uint32_t* rhs_indices [[buffer(5)]],
1146 device T* y [[buffer(6)]],
1147 const constant int& in_vec_size [[buffer(7)]],
1148 const constant int& out_vec_size [[buffer(8)]],
1149 const constant int& batch_ndims [[buffer(9)]],
1150 const constant int* batch_shape [[buffer(10)]],
1151 const constant size_t* lhs_strides [[buffer(11)]],
1152 const constant size_t* rhs_strides [[buffer(12)]],
1153 const constant int& x_batch_ndims [[buffer(13)]],
1154 const constant int* x_shape [[buffer(14)]],
1155 const constant size_t* x_strides [[buffer(15)]],
1156 const constant int& w_batch_ndims [[buffer(16)]],
1157 const constant int* w_shape [[buffer(17)]],
1158 const constant size_t* w_strides [[buffer(18)]],
1159 const constant size_t* s_strides [[buffer(19)]],
1160 const constant size_t* b_strides [[buffer(20)]],
1161 uint3 tid [[threadgroup_position_in_grid]],
1162 uint simd_gid [[simdgroup_index_in_threadgroup]],
1163 uint simd_lid [[thread_index_in_simdgroup]]) {
1164 adjust_matrix_offsets<T>(
1165 x,
1166 w,
1167 scales,
1168 biases,
1169 lhs_indices,
1170 rhs_indices,
1171 y,
1172 out_vec_size,
1173 batch_ndims,
1174 batch_shape,
1175 lhs_strides,
1176 rhs_strides,
1177 x_batch_ndims,
1178 x_shape,
1179 x_strides,
1180 w_batch_ndims,
1181 w_shape,
1182 w_strides,
1183 s_strides,
1184 b_strides,
1185 tid);
1186 qmv_fast_impl<T, group_size, bits>(
1187 w,
1188 scales,
1189 biases,
1190 x,
1191 y,
1192 in_vec_size,
1193 out_vec_size,
1194 tid,
1195 simd_gid,
1196 simd_lid);
1197}
1198
1199template <typename T, int group_size, int bits>
1200[[kernel]] void bs_qmv(
1201 const device uint32_t* w [[buffer(0)]],
1202 const device T* scales [[buffer(1)]],
1203 const device T* biases [[buffer(2)]],
1204 const device T* x [[buffer(3)]],
1205 const device uint32_t* lhs_indices [[buffer(4)]],
1206 const device uint32_t* rhs_indices [[buffer(5)]],
1207 device T* y [[buffer(6)]],
1208 const constant int& in_vec_size [[buffer(7)]],
1209 const constant int& out_vec_size [[buffer(8)]],
1210 const constant int& batch_ndims [[buffer(9)]],
1211 const constant int* batch_shape [[buffer(10)]],
1212 const constant size_t* lhs_strides [[buffer(11)]],
1213 const constant size_t* rhs_strides [[buffer(12)]],
1214 const constant int& x_batch_ndims [[buffer(13)]],
1215 const constant int* x_shape [[buffer(14)]],
1216 const constant size_t* x_strides [[buffer(15)]],
1217 const constant int& w_batch_ndims [[buffer(16)]],
1218 const constant int* w_shape [[buffer(17)]],
1219 const constant size_t* w_strides [[buffer(18)]],
1220 const constant size_t* s_strides [[buffer(19)]],
1221 const constant size_t* b_strides [[buffer(20)]],
1222 uint3 tid [[threadgroup_position_in_grid]],
1223 uint simd_gid [[simdgroup_index_in_threadgroup]],
1224 uint simd_lid [[thread_index_in_simdgroup]]) {
1225 adjust_matrix_offsets<T>(
1226 x,
1227 w,
1228 scales,
1229 biases,
1230 lhs_indices,
1231 rhs_indices,
1232 y,
1233 out_vec_size,
1234 batch_ndims,
1235 batch_shape,
1236 lhs_strides,
1237 rhs_strides,
1238 x_batch_ndims,
1239 x_shape,
1240 x_strides,
1241 w_batch_ndims,
1242 w_shape,
1243 w_strides,
1244 s_strides,
1245 b_strides,
1246 tid);
1247 qmv_impl<T, group_size, bits>(
1248 w,
1249 scales,
1250 biases,
1251 x,
1252 y,
1253 in_vec_size,
1254 out_vec_size,
1255 tid,
1256 simd_gid,
1257 simd_lid);
1258}
1259
1260template <typename T, int group_size, int bits>
1261[[kernel]] void bs_qvm(
1262 const device T* x [[buffer(0)]],
1263 const device uint32_t* w [[buffer(1)]],
1264 const device T* scales [[buffer(2)]],
1265 const device T* biases [[buffer(3)]],
1266 const device uint32_t* lhs_indices [[buffer(4)]],
1267 const device uint32_t* rhs_indices [[buffer(5)]],
1268 device T* y [[buffer(6)]],
1269 const constant int& in_vec_size [[buffer(7)]],
1270 const constant int& out_vec_size [[buffer(8)]],
1271 const constant int& batch_ndims [[buffer(9)]],
1272 const constant int* batch_shape [[buffer(10)]],
1273 const constant size_t* lhs_strides [[buffer(11)]],
1274 const constant size_t* rhs_strides [[buffer(12)]],
1275 const constant int& x_batch_ndims [[buffer(13)]],
1276 const constant int* x_shape [[buffer(14)]],
1277 const constant size_t* x_strides [[buffer(15)]],
1278 const constant int& w_batch_ndims [[buffer(16)]],
1279 const constant int* w_shape [[buffer(17)]],
1280 const constant size_t* w_strides [[buffer(18)]],
1281 const constant size_t* s_strides [[buffer(19)]],
1282 const constant size_t* b_strides [[buffer(20)]],
1283 uint3 tid [[threadgroup_position_in_grid]],
1284 uint simd_gid [[simdgroup_index_in_threadgroup]],
1285 uint simd_lid [[thread_index_in_simdgroup]]) {
1286 adjust_matrix_offsets<T>(
1287 x,
1288 w,
1289 scales,
1290 biases,
1291 lhs_indices,
1292 rhs_indices,
1293 y,
1294 out_vec_size,
1295 batch_ndims,
1296 batch_shape,
1297 lhs_strides,
1298 rhs_strides,
1299 x_batch_ndims,
1300 x_shape,
1301 x_strides,
1302 w_batch_ndims,
1303 w_shape,
1304 w_strides,
1305 s_strides,
1306 b_strides,
1307 tid);
1308 qvm_impl<T, group_size, bits>(
1309 x,
1310 w,
1311 scales,
1312 biases,
1313 y,
1314 in_vec_size,
1315 out_vec_size,
1316 tid,
1317 simd_gid,
1318 simd_lid);
1319}
1320
1321template <
1322 typename T,
1323 const int group_size,
1324 const int bits,
1325 const bool aligned_N,
1326 const int BM = 32,
1327 const int BK = 32,
1328 const int BN = 32>
1329[[kernel]] void bs_qmm_t(
1330 const device T* x [[buffer(0)]],
1331 const device uint32_t* w [[buffer(1)]],
1332 const device T* scales [[buffer(2)]],
1333 const device T* biases [[buffer(3)]],
1334 const device uint32_t* lhs_indices [[buffer(4)]],
1335 const device uint32_t* rhs_indices [[buffer(5)]],
1336 device T* y [[buffer(6)]],
1337 const constant int& M [[buffer(7)]],
1338 const constant int& N [[buffer(8)]],
1339 const constant int& K [[buffer(9)]],
1340 const constant int& batch_ndims [[buffer(10)]],
1341 const constant int* batch_shape [[buffer(11)]],
1342 const constant size_t* lhs_strides [[buffer(12)]],
1343 const constant size_t* rhs_strides [[buffer(13)]],
1344 const constant int& x_batch_ndims [[buffer(14)]],
1345 const constant int* x_shape [[buffer(15)]],
1346 const constant size_t* x_strides [[buffer(16)]],
1347 const constant int& w_batch_ndims [[buffer(17)]],
1348 const constant int* w_shape [[buffer(18)]],
1349 const constant size_t* w_strides [[buffer(19)]],
1350 const constant size_t* s_strides [[buffer(20)]],
1351 const constant size_t* b_strides [[buffer(21)]],
1352 uint3 tid [[threadgroup_position_in_grid]],
1353 uint lid [[thread_index_in_threadgroup]],
1354 uint simd_gid [[simdgroup_index_in_threadgroup]],
1355 uint simd_lid [[thread_index_in_simdgroup]]) {
1356 (void)lid;
1357
1358 constexpr int BK_padded = (BK + 16 / sizeof(T));
1359
1360 threadgroup T Xs[BM * BK_padded];
1361 threadgroup T Ws[BN * BK_padded];
1362
1363 adjust_matrix_offsets<T>(
1364 x,
1365 w,
1366 scales,
1367 biases,
1368 lhs_indices,
1369 rhs_indices,
1370 y,
1371 M * N,
1372 batch_ndims,
1373 batch_shape,
1374 lhs_strides,
1375 rhs_strides,
1376 x_batch_ndims,
1377 x_shape,
1378 x_strides,
1379 w_batch_ndims,
1380 w_shape,
1381 w_strides,
1382 s_strides,
1383 b_strides,
1384 tid);
1385 qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
1386 x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
1387}
1388
1389template <
1390 typename T,
1391 const int group_size,
1392 const int bits,
1393 const int BM = 32,
1394 const int BK = 32,
1395 const int BN = 32>
1396[[kernel]] void bs_qmm_n(
1397 const device T* x [[buffer(0)]],
1398 const device uint32_t* w [[buffer(1)]],
1399 const device T* scales [[buffer(2)]],
1400 const device T* biases [[buffer(3)]],
1401 const device uint32_t* lhs_indices [[buffer(4)]],
1402 const device uint32_t* rhs_indices [[buffer(5)]],
1403 device T* y [[buffer(6)]],
1404 const constant int& M [[buffer(7)]],
1405 const constant int& N [[buffer(8)]],
1406 const constant int& K [[buffer(9)]],
1407 const constant int& batch_ndims [[buffer(10)]],
1408 const constant int* batch_shape [[buffer(11)]],
1409 const constant size_t* lhs_strides [[buffer(12)]],
1410 const constant size_t* rhs_strides [[buffer(13)]],
1411 const constant int& x_batch_ndims [[buffer(14)]],
1412 const constant int* x_shape [[buffer(15)]],
1413 const constant size_t* x_strides [[buffer(16)]],
1414 const constant int& w_batch_ndims [[buffer(17)]],
1415 const constant int* w_shape [[buffer(18)]],
1416 const constant size_t* w_strides [[buffer(19)]],
1417 const constant size_t* s_strides [[buffer(20)]],
1418 const constant size_t* b_strides [[buffer(21)]],
1419 uint3 tid [[threadgroup_position_in_grid]],
1420 uint lid [[thread_index_in_threadgroup]],
1421 uint simd_gid [[simdgroup_index_in_threadgroup]],
1422 uint simd_lid [[thread_index_in_simdgroup]]) {
1423 (void)lid;
1424
1425 constexpr int BK_padded = (BK + 16 / sizeof(T));
1426 constexpr int BN_padded = (BN + 16 / sizeof(T));
1427
1428 threadgroup T Xs[BM * BK_padded];
1429 threadgroup T Ws[BK * BN_padded];
1430
1431 adjust_matrix_offsets<T>(
1432 x,
1433 w,
1434 scales,
1435 biases,
1436 lhs_indices,
1437 rhs_indices,
1438 y,
1439 M * N,
1440 batch_ndims,
1441 batch_shape,
1442 lhs_strides,
1443 rhs_strides,
1444 x_batch_ndims,
1445 x_shape,
1446 x_strides,
1447 w_batch_ndims,
1448 w_shape,
1449 w_strides,
1450 s_strides,
1451 b_strides,
1452 tid);
1453 qmm_n_impl<T, group_size, bits, BM, BK, BN>(
1454 x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
1455}
1456
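// affine_quantize: each simdgroup quantizes one group of group_size weights.
// The group minimum/maximum are reduced with simd_min/simd_max, scale and
// bias are chosen so one quantization level lands exactly on the edge value,
// and the packed codes are assembled with simd_shuffle_down when one output
// byte spans several threads.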
1457template <typename T, const int group_size, const int bits>
1458[[kernel]] void affine_quantize(
1459 const device T* w [[buffer(0)]],
1460 device uint8_t* out [[buffer(1)]],
1461 device T* scales [[buffer(2)]],
1462 device T* biases [[buffer(3)]],
1463 uint index [[thread_position_in_grid]]) {
1464 constexpr T eps = T(1e-7);
1465 constexpr int simd_size = 32;
1466 constexpr int uint8_bits = 8;
1467 constexpr T n_bins = (1 << bits) - 1;
1468 constexpr int packs_per_int = uint8_bits / bits;
1469 constexpr int values_per_reduce = group_size / simd_size;
1470 constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
1471 constexpr int writes_per_pack =
1472 writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
1473
1474 static_assert(
1475 group_size % simd_size == 0,
1476 "Group size must be divisible by simd size.");
1477
1478 int in_index = index * values_per_reduce;
1479 int out_index = index * writes_per_pack;
1480
1481 T w_thread[values_per_reduce];
1482 T w_min = Limits<T>::max;
1483 T w_max = 0;
1484
1485#pragma clang loop unroll(full)
1486 for (int i = 0; i < values_per_reduce; i++) {
1487 T val = w[in_index + i];
1488 w_thread[i] = val;
1489 w_min = min(w_min, val);
1490 w_max = max(w_max, val);
1491 }
1492
1493 w_min = simd_min(w_min);
1494 w_max = simd_max(w_max);
1495
1496 T scale = max((w_max - w_min) / n_bins, eps);
1497 bool side = abs(w_min) > abs(w_max);
1498 scale = side ? scale : -scale;
1499 T edge = side ? w_min : w_max;
1500 T q0 = round(edge / scale);
1501 bool at_zero = q0 == 0.0f;
1502 scale = at_zero ? scale : edge / q0;
1503 T bias = at_zero ? T(0) : edge;
1504
1505 // Write out the scales and biases
1506 int gindex = in_index / group_size;
1507 if (in_index % group_size == 0) {
1508 scales[gindex] = scale;
1509 biases[gindex] = bias;
1510 }
1511
1512 uint8_t output = 0;
1513#pragma clang loop unroll(full)
1514 for (int i = 0; i < values_per_reduce; i++) {
1515 uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
1516 if (bits == 8) {
1517 output = val;
1518 } else {
1519 output += val << (bits * (i % packs_per_int));
1520 }
1521
1522 if (packs_per_int < values_per_reduce &&
1523 i % packs_per_int == packs_per_int - 1) {
1524 out[out_index + i / packs_per_int] = output;
1525 output = 0;
1526 } else {
1527#pragma clang loop unroll(full)
1528 for (int j = 0; j < writes_per_reduce - 1; j++) {
1529 uint8_t sval = simd_shuffle_down(val, j + 1);
1530 output += sval << (bits * (values_per_reduce + j + i));
1531 }
1532 }
1533 }
1534 if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
1535 out[out_index / writes_per_reduce] = output;
1536 }
1537}
1538
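// affine_quantize_scales_biases: quantize w using precomputed scales and
// biases (one pair per group_size values), packing packs_per_int codes into
// each output byte.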
1539template <typename T, const int group_size, const int bits>
1540[[kernel]] void affine_quantize_scales_biases(
1541 const device T* w [[buffer(0)]],
1542 const device T* scales [[buffer(1)]],
1543 const device T* biases [[buffer(2)]],
1544 device uint8_t* out [[buffer(3)]],
1545 uint index [[thread_position_in_grid]]) {
1546 constexpr int uint8_bits = 8;
1547 constexpr int packs_per_int = uint8_bits / bits;
1548 constexpr T n_bins = (1 << bits) - 1;
1549
1550 int in_index = index * packs_per_int;
1551 int gindex = in_index / group_size;
1552 T scale = scales[gindex];
1553 T bias = biases[gindex];
1554
1555 uint8_t output = 0;
1556#pragma clang loop unroll(full)
1557 for (int i = 0; i < packs_per_int; i++) {
1558 uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
1559 if (bits == 8) {
1560 output = val;
1561 } else {
1562 output += val << (bits * i);
1563 }
1564 }
1565 out[index] = output;
1566}
1567
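// affine_dequantize: unpack packs_per_int codes from each byte and map them
// back to scale * q + bias.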
1568template <typename T, const int group_size, const int bits>
1569[[kernel]] void affine_dequantize(
1570 const device uint8_t* w [[buffer(0)]],
1571 const device T* scales [[buffer(1)]],
1572 const device T* biases [[buffer(2)]],
1573 device T* out [[buffer(3)]],
1574 uint index [[thread_position_in_grid]]) {
1575 constexpr int uint8_bits = 8;
1576 constexpr int packs_per_int = uint8_bits / bits;
1577
1578 int oindex = index * packs_per_int;
1579 int gindex = oindex / group_size;
1580 T scale = scales[gindex];
1581 T bias = biases[gindex];
1582 uint val = w[index];
1583
1584#pragma clang loop unroll(full)
1585 for (int i = 0; i < packs_per_int; i++) {
1586 uint8_t d;
1587 if (bits == 2) {
1588 d = (val >> (bits * i)) & 0x03;
1589 } else if (bits == 4) {
1590 d = (val >> (bits * i)) & 0x0f;
1591 } else if (bits == 8) {
1592 d = val;
1593 }
1594 out[oindex + i] = scale * d + bias;
1595 }
1596}