gemv_masked.h
1// Copyright © 2023-2024 Apple Inc.
2
4
5using namespace metal;
6
7#define MLX_MTL_CONST static constant constexpr const
8#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
9
10struct _NoMask {
11 char x;
12
13 constexpr METAL_FUNC operator bool() {
14 return true;
15 }
16 constexpr METAL_FUNC operator bool() const threadgroup {
17 return true;
18 }
19 constexpr METAL_FUNC operator bool() const device {
20 return true;
21 }
22 constexpr METAL_FUNC operator bool() const constant {
23 return true;
24 }
25};
26
27typedef struct _NoMask nomask_t;
28
29template <typename OutT, typename InT = OutT>
30struct ScaleOp {
31 OutT scale;
32
33 METAL_FUNC OutT apply(InT x) const {
34 return static_cast<OutT>(x) * scale;
35 }
36};
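// Illustrative usage (not part of the original header): ScaleOp just casts and
// multiplies, e.g. ScaleOp<float, int>{0.5f}.apply(4) yields 2.0f. The masked
// kernels below apply the same multiply-by-mask idea directly through their
// local block_scale and out_scale values.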
37
38template <
39 typename T,
40 typename out_mask_t,
41 typename op_mask_t,
42 const int BM, /* Threadgroup rows (in simdgroups) */
43 const int BN, /* Threadgroup cols (in simdgroups) */
44 const int SM, /* Simdgroup rows (in threads) */
45 const int SN, /* Simdgroup cols (in threads) */
46 const int TM, /* Thread rows (in elements) */
47 const int TN> /* Thread cols (in elements) */
48struct GEMVKernel {
49 MLX_MTL_CONST int threadsM = BM * SM;
50 MLX_MTL_CONST int threadsN = BN * SN;
51
52 MLX_MTL_CONST int blockM = threadsM * TM;
53 MLX_MTL_CONST int blockN = threadsN * TN;
54
55 static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
56
57 static_assert(
58 SN == 8 || SN == 16 || SN == 32,
59 "gemv block must have a width of 8, 16, or 32");
60
61 static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
62
63 MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
64 MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
65
66 MLX_MTL_CONST bool has_mul_operand_mask =
67 has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
68 MLX_MTL_CONST bool has_mul_output_mask =
69 has_output_mask && !metal::is_same_v<out_mask_t, bool>;
70
71 // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
72 // into blocks of (blockM, blockN), which are distributed among threadgroups
73 // - Every thread works on a block of (TM, TN)
74 // - We assume each threadgroup has (threadsN, threadsM, 1) threads
75 //
76 // 1. A thread loads TN elements each from mat along TM rows
77 // and the corresponding TN elements from in_vec
78 // 2. The thread then multiplies and adds to accumulate its local result for
79 // the block
80 // 3. At the end, each thread has accumulated results over all blocks across
81 // the rows. These are then summed up across the threadgroup
82 // 4. Each threadgroup writes its accumulated blockM outputs
83 //
84 // Edge case handling:
85 // - The threadgroup with the largest tid has blocks that exceed the matrix
86 // * The blocks that start outside the matrix are never read (thread results
87 // remain zero)
88 // * The last thread that partially overlaps with the matrix is shifted
89 // inwards such that the thread block fits exactly in the matrix
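// Worked example (illustrative parameter choice, not mandated by this header):
// with BM = 2, BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the static_asserts hold
// (SM * SN == 32, SN == 8, blockN >= blockM); the threadgroup has
// threadsM = 8 x threadsN = 16 threads, covers blockM = threadsM * TM = 32
// output rows, consumes blockN = threadsN * TN = 64 elements of in_vec per
// main-loop iteration, and tgp_mem_size = BN * (blockM + TM) = 72.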
90
91 MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
92 MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
93
94 static METAL_FUNC void
95 load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
96 MLX_MTL_PRAGMA_UNROLL
97 for (int tn = 0; tn < TN; tn++) {
98 dst[tn] = src[src_offset + tn];
99 }
100 }
101
102 static METAL_FUNC void load_safe(
103 const device T* src,
104 thread T dst[TN],
105 const int src_offset = 0,
106 const int src_size = TN) {
107 if (src_offset + TN <= src_size) {
108 MLX_MTL_PRAGMA_UNROLL
109 for (int tn = 0; tn < TN; tn++) {
110 dst[tn] = src[src_offset + tn];
111 }
112 } else { // Edge case
113 MLX_MTL_PRAGMA_UNROLL
114 for (int tn = 0; tn < TN; tn++) {
115 dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
116 }
117 }
118 }
119
120 static METAL_FUNC void run(
121 const device T* mat [[buffer(0)]],
122 const device T* in_vec [[buffer(1)]],
123 device T* out_vec [[buffer(3)]],
124 const constant int& in_vec_size [[buffer(4)]],
125 const constant int& out_vec_size [[buffer(5)]],
126 const constant int& matrix_ld [[buffer(6)]],
127 const device out_mask_t* out_mask [[buffer(20)]],
128 const device op_mask_t* mat_mask [[buffer(21)]],
129 const device op_mask_t* vec_mask [[buffer(22)]],
130 const constant int* mask_strides [[buffer(23)]],
131 threadgroup T* tgp_memory [[threadgroup(0)]],
132 uint3 tid [[threadgroup_position_in_grid]],
133 uint3 lid [[thread_position_in_threadgroup]],
134 uint simd_gid [[simdgroup_index_in_threadgroup]],
135 uint simd_lid [[thread_index_in_simdgroup]]) {
136 // Appease compiler
137 (void)lid;
138
139 // Thread local accumulation results
140 thread T result[TM] = {0};
141 thread T inter[TN];
142 thread T v_coeff[TN];
143
144 const int thrM = SN != 32 ? simd_lid / SN : 0;
145 const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
146
147 const int sgN = BN != 1 ? (simd_gid % BN) : 0;
148
149 const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
150 const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
151
152 int bm = (simdM + thrM) * TM;
153 int bn = (simdN + thrN) * TN;
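// Illustrative reading of the indexing above: lanes are laid out row-major in
// the simdgroup, so with SN = 8 the lane simd_lid = 19 sits at
// thrM = 19 / 8 = 2, thrN = 19 % 8 = 3; bm = (simdM + 2) * TM is its row
// offset inside the threadgroup's block of output rows and bn = (simdN + 3) * TN
// its starting column within the current blockN-wide slice of in_vec.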
154
155 // Block position
156 int out_row = tid.x * blockM + bm;
157
158 // Exit simdgroup if rows are out of bounds
159 if (out_row >= out_vec_size)
160 return;
161
162 // Adjust tail simdgroup to ensure in-bounds reads
163 out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
164
165 // Prepare mask offsets
166 const constant int* out_mask_strides = mask_strides;
167 const constant int* mat_mask_strides =
168 mask_strides + (has_output_mask ? 2 : 0);
169 const constant int* vec_mask_strides =
170 mat_mask_strides + (has_operand_mask ? 2 : 0);
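// Note (added for clarity): mask_strides is packed as up to three pairs of
// ints, in the order output mask, matrix mask, vector mask, and a pair is only
// present when the corresponding mask is; that is why the pointers above
// advance by 2 per enabled mask.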
171
172 const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
173
174 const int out_mask_offset =
175 !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
176
177 int mat_mask_offset =
178 !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
179 int vec_mask_offset = 0;
180 const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
181 const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
182
183 T out_scale{1};
184
185 // Check output mask
186 if (has_output_mask) {
187 auto mask_out = out_mask[out_mask_offset];
188
189 // Write zeros and return if mask is 0
190 if (!mask_out) {
191 if (simdN == 0 && thrN == 0) {
192 MLX_MTL_PRAGMA_UNROLL
193 for (int tm = 0; tm < TM; tm++) {
194 out_vec[out_row + tm] = T(0.);
195 }
196 }
197
198 return;
199 }
200
201 // Store scalar if multiplicative mask
202 if (has_mul_output_mask) {
203 out_scale = T(mask_out);
204 }
205 }
206
207 // Advance matrix
208 mat += out_row * matrix_ld;
209
210 // Prepare for loop
211 constexpr const uniform<int> loop_stride = make_uniform(blockN);
212 const uniform<int> in_size = make_uniform(in_vec_size);
213 const uniform<int> n_iter = in_size / loop_stride;
214 const uniform<int> last_iter = loop_stride * n_iter;
215 const uniform<int> leftover = in_size - last_iter;
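// Trip-count example (illustrative numbers): with in_vec_size = 100 and
// blockN = 64, n_iter = 100 / 64 = 1, last_iter = 64 and leftover = 36, so one
// full block goes through the main loop and the remaining 36 columns fall
// through to the guarded leftover path below.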
216
217 // Loop over in_vec in blocks of blockN
218 for (int i = 0; i < n_iter; ++i) {
219 if (!has_operand_mask ||
220 (bool(mat_mask[mat_mask_offset]) &&
221 bool(vec_mask[vec_mask_offset]))) {
222 T block_scale{1};
223 if (has_mul_operand_mask) {
224 block_scale =
225 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
226 }
227
228 load_unsafe(in_vec, v_coeff, bn);
229
230 // Apply scale
231 if (has_mul_operand_mask) {
232 MLX_MTL_PRAGMA_UNROLL
233 for (int tn = 0; tn < TN; tn++) {
234 v_coeff[tn] *= block_scale;
235 }
236 }
237
238 // Per thread work loop
239 int mat_offset = 0;
240 MLX_MTL_PRAGMA_UNROLL
241 for (int tm = 0; tm < TM; tm++) {
242 // Load for the row
243 load_unsafe(mat, inter, mat_offset + bn);
244
245 // Accumulate results
246 MLX_MTL_PRAGMA_UNROLL
247 for (int tn = 0; tn < TN; tn++) {
248 result[tm] += inter[tn] * v_coeff[tn];
249 }
250
251 mat_offset += matrix_ld;
252 }
253 }
254
255 bn += blockN;
256 mat_mask_offset += mat_mask_step;
257 vec_mask_offset += vec_mask_step;
258 }
259
260 if (leftover > 0 &&
261 (!has_operand_mask ||
262 (bool(mat_mask[mat_mask_offset]) &&
263 bool(vec_mask[vec_mask_offset])))) {
264 T block_scale{1};
265 if (has_mul_operand_mask) {
266 block_scale =
267 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
268 }
269
270 load_safe(in_vec, v_coeff, bn, in_size);
271
272 // Apply scale
273 if (has_mul_operand_mask) {
274 MLX_MTL_PRAGMA_UNROLL
275 for (int tn = 0; tn < TN; tn++) {
276 v_coeff[tn] *= block_scale;
277 }
278 }
279
280 // Per thread work loop
281 MLX_MTL_PRAGMA_UNROLL
282 for (int tm = 0; tm < TM; tm++) {
283 // Load for the row
284 load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
285
286 // Accumulate results
287 MLX_MTL_PRAGMA_UNROLL
288 for (int tn = 0; tn < TN; tn++) {
289 result[tm] += inter[tn] * v_coeff[tn];
290 }
291 }
292 }
293
294 // Apply out scale
295 if (has_mul_output_mask) {
296 MLX_MTL_PRAGMA_UNROLL
297 for (int tm = 0; tm < TM; tm++) {
298 result[tm] *= out_scale;
299 }
300 }
301
302 // Simdgroup accumulations
303 MLX_MTL_PRAGMA_UNROLL
304 for (int tm = 0; tm < TM; tm++) {
305 MLX_MTL_PRAGMA_UNROLL
306 for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
307 result[tm] += simd_shuffle_down(result[tm], sn);
308 }
309 }
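// The loop above is a shuffle-down tree reduction over the SN column lanes of
// each thread row: with SN = 8 (illustrative) the offsets are 4, 2, 1, and
// afterwards the lane with thrN == 0 holds its row's complete partial sum,
// which is why only thrN == 0 takes part in the threadgroup reduction and the
// final write below.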
310
311 // Threadgroup accumulation results
312 if (needs_tgp_reduction) {
313 threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
314 if (thrN == 0) {
315 MLX_MTL_PRAGMA_UNROLL
316 for (int tm = 0; tm < TM; tm++) {
317 tgp_results[tm] = result[tm];
318 }
319
320 threadgroup_barrier(mem_flags::mem_none);
321
322 if (sgN == 0) {
323 MLX_MTL_PRAGMA_UNROLL
324 for (int sgn = 1; sgn < BN; sgn++) {
325 MLX_MTL_PRAGMA_UNROLL
326 for (int tm = 0; tm < TM; tm++) {
327 result[tm] += tgp_results[sgn * (blockM + TM) + tm];
328 }
329 }
330 }
331 }
332 }
333
334 // Write outputs
335 if (simdN == 0 && thrN == 0) {
336 MLX_MTL_PRAGMA_UNROLL
337 for (int tm = 0; tm < TM; tm++) {
338 out_vec[out_row + tm] = result[tm];
339 }
340 }
341 }
342};
343
344///////////////////////////////////////////////////////////////////////////////
345/// Vector matrix multiplication
346///////////////////////////////////////////////////////////////////////////////
347
348template <
349 typename T,
350 typename out_mask_t,
351 typename op_mask_t,
352 const int BM, /* Threadgroup rows (in simdgroups) */
353 const int BN, /* Threadgroup cols (in simdgroups) */
354 const int SM, /* Simdgroup rows (in threads) */
355 const int SN, /* Simdgroup cols (in threads) */
356 const int TM, /* Thread rows (in elements) */
357 const int TN> /* Thread cols (in elements) */
358struct GEMVTKernel {
359 MLX_MTL_CONST int threadsM = BM * SM;
360 MLX_MTL_CONST int threadsN = BN * SN;
361
362 MLX_MTL_CONST int blockM = threadsM * TM;
363 MLX_MTL_CONST int blockN = threadsN * TN;
364
365 static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
366
367 MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
368 MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
369
370 MLX_MTL_CONST bool has_mul_operand_mask =
371 has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
372 MLX_MTL_CONST bool has_mul_output_mask =
373 has_output_mask && !metal::is_same_v<out_mask_t, bool>;
374
375 // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
376 // into blocks of (blockM, blockN), which are distributed among threadgroups
377 // - Every thread works on a block of (TM, TN)
378 // - We assume each threadgroup has (threadsN, threadsM, 1) threads
379 //
380 // 1. A thread loads TN elements each from mat along TM contiguous rows
381 // and the corresponding scalar from the vector
382 // 2. The thread then accumulates its local result for the block
383 // 3. At the end, each thread has accumulated results over all blocks across
384 // the rows. These are then summed up across the threadgroup
385 // 4. Each threadgroup writes its accumulated blockN (= threadsN * TN) outputs
386 //
387 // Edge case handling:
388 // - The threadgroup with the largest tid has blocks that exceed the matrix
389 // * The blocks that start outside the matrix are never read (thread results
390 // remain zero)
391 // * The last thread that partially overlaps with the matrix is shifted
392 // inwards such that the thread block fits exactly in the matrix
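// Worked example (illustrative parameter choice, not mandated by this header):
// with BM = 2, BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the threadgroup has
// threadsM = 8 x threadsN = 16 threads, produces blockN = threadsN * TN = 64
// output columns, walks in_vec in steps of blockM = threadsM * TM = 32 rows,
// and tgp_mem_size = BM * (blockN + TN) = 2 * 68 = 136.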
393
394 MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
395 MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
396
397 static METAL_FUNC void run(
398 const device T* mat [[buffer(0)]],
399 const device T* in_vec [[buffer(1)]],
400 device T* out_vec [[buffer(3)]],
401 const constant int& in_vec_size [[buffer(4)]],
402 const constant int& out_vec_size [[buffer(5)]],
403 const constant int& marix_ld [[buffer(6)]],
404 const device out_mask_t* out_mask [[buffer(20)]],
405 const device op_mask_t* mat_mask [[buffer(21)]],
406 const device op_mask_t* vec_mask [[buffer(22)]],
407 const constant int* mask_strides [[buffer(23)]],
408 threadgroup T* tgp_memory [[threadgroup(0)]],
409 uint3 tid [[threadgroup_position_in_grid]],
410 uint3 lid [[thread_position_in_threadgroup]],
411 uint simd_gid [[simdgroup_index_in_threadgroup]],
412 uint simd_lid [[thread_index_in_simdgroup]]) {
413 // Appease compiler
414 (void)lid;
415
416 // Thread local accumulation results
417 T result[TN] = {0};
418 T inter[TN];
419 T v_coeff[TM];
420
421 const int thrM = SN != 32 ? simd_lid / SN : 0;
422 const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
423
424 const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
425 const int sgN = BN != 1 ? (simd_gid % BN) : 0;
426
427 const int simdM = SM * sgM;
428 const int simdN = SN * sgN;
429
430 int cm = (simdM + thrM);
431 int cn = (simdN + thrN);
432
433 int bm = cm * TM;
434 int bn = cn * TN;
435
436 int out_col = tid.x * blockN + bn;
437
438 // Prepare mask offsets
439 const constant int* out_mask_strides = mask_strides;
440 const constant int* mat_mask_strides =
441 out_mask_strides + (has_output_mask ? 2 : 0);
442 const constant int* vec_mask_strides =
443 mat_mask_strides + (has_operand_mask ? 2 : 0);
444
445 const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
446
447 const int out_mask_offset =
448 !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
449
450 int mat_mask_offset =
451 !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
452 int vec_mask_offset = 0;
453 const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
454 const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
455
456 T out_scale{1};
457
458 // Check output mask
459 if (has_output_mask) {
460 auto mask_out = out_mask[out_mask_offset];
461
462 // Write zeros and return if mask is 0
463 if (!mask_out) {
464 if (cm == 0 && out_col < out_vec_size) {
465 if (out_col + TN <= out_vec_size) {
466 MLX_MTL_PRAGMA_UNROLL
467 for (int tn = 0; tn < TN; tn++) {
468 out_vec[out_col + tn] = T(0.);
469 }
470 } else {
471 for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
472 out_vec[out_col + tn] = T(0.);
473 }
474 }
475 }
476
477 return;
478 }
479
480 // Store scalar if multiplicative mask
481 if (has_mul_output_mask) {
482 out_scale = T(mask_out);
483 }
484 }
485
486 // Prepare for loop
487 constexpr const uniform<int> loop_stride = make_uniform(blockM);
488 const uniform<int> in_size = make_uniform(in_vec_size);
489 const uniform<int> n_iter = in_size / loop_stride;
490 const uniform<int> last_iter = loop_stride * n_iter;
491 const uniform<int> leftover = in_size - last_iter;
492
493 // Edge case handling
494 if (out_col < out_vec_size) {
495 out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
496
497 // Per thread accumulation main loop
498 for (int i = 0; i < n_iter; ++i) {
499 // Adding a threadgroup_barrier improves performance slightly
500 // possibly because it helps the kernel make better use of the cache
501 threadgroup_barrier(mem_flags::mem_none);
502
503 if (!has_operand_mask ||
504 (bool(mat_mask[mat_mask_offset]) &&
505 bool(vec_mask[vec_mask_offset]))) {
506 T block_scale{1};
507 if (has_mul_operand_mask) {
508 block_scale =
509 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
510 }
511
512 MLX_MTL_PRAGMA_UNROLL
513 for (int tm = 0; tm < TM; tm++) {
514 v_coeff[tm] = in_vec[bm + tm];
515 }
516
517 // Apply scale
518 if (has_mul_operand_mask) {
519 MLX_MTL_PRAGMA_UNROLL
520 for (int tm = 0; tm < TM; tm++) {
521 v_coeff[tm] *= block_scale;
522 }
523 }
524
525 MLX_MTL_PRAGMA_UNROLL
526 for (int tm = 0; tm < TM; tm++) {
527 for (int tn = 0; tn < TN; tn++) {
528 inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
529 }
530 for (int tn = 0; tn < TN; tn++) {
531 result[tn] += v_coeff[tm] * inter[tn];
532 }
533 }
534 }
535
536 bm += blockM;
537 mat_mask_offset += mat_mask_step;
538 vec_mask_offset += vec_mask_step;
539 }
540
541 if (leftover > 0 &&
542 (!has_operand_mask ||
543 (bool(mat_mask[mat_mask_offset]) &&
544 bool(vec_mask[vec_mask_offset])))) {
545 T block_scale{1};
546 if (has_mul_operand_mask) {
547 block_scale =
548 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
549 }
550
551 for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
552 v_coeff[tm] = in_vec[bm + tm];
553
554 if (has_mul_operand_mask) {
555 v_coeff[tm] *= block_scale;
556 }
557
558 MLX_MTL_PRAGMA_UNROLL
559 for (int tn = 0; tn < TN; tn++) {
560 inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
561 }
562
563 MLX_MTL_PRAGMA_UNROLL
564 for (int tn = 0; tn < TN; tn++) {
565 result[tn] += v_coeff[tm] * inter[tn];
566 }
567 }
568 }
569 }
570
571 // Apply out scale
572 if (has_mul_output_mask) {
573 MLX_MTL_PRAGMA_UNROLL
574 for (int tn = 0; tn < TN; tn++) {
575 result[tn] *= out_scale;
576 }
577 }
578
579 // Simdgroup accumulations
580 MLX_MTL_PRAGMA_UNROLL
581 for (int tn = 0; tn < TN; tn++) {
582 MLX_MTL_PRAGMA_UNROLL
583 for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
584 result[tn] += simd_shuffle_down(result[tn], SN * sm);
585 }
586 }
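// The shuffle distance above is SN * sm because lanes are laid out row-major
// (thrM = simd_lid / SN), so moving sm rows within the simdgroup means moving
// SN * sm lanes; with SM = 4, SN = 8 (illustrative) the offsets are 16 and 8,
// after which the thrM == 0 lane of each column holds that column's partial
// sum.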
587
588 // Threadgroup accumulation results
589 if (needs_tgp_reduction) {
590 threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
591 if (thrM == 0) {
592 MLX_MTL_PRAGMA_UNROLL
593 for (int tn = 0; tn < TN; tn++) {
594 tgp_results[tn] = result[tn];
595 }
596
597 threadgroup_barrier(mem_flags::mem_none);
598
599 if (sgM == 0) {
600 MLX_MTL_PRAGMA_UNROLL
601 for (int sgm = 1; sgm < BM; sgm++) {
602 MLX_MTL_PRAGMA_UNROLL
603 for (int tn = 0; tn < TN; tn++) {
604 result[tn] += tgp_results[sgm * (blockN + TN) + tn];
605 }
606 }
607 }
608 }
609 }
610
611 // Threadgroup accumulation and writing out results
612 if (cm == 0 && out_col < out_vec_size) {
613 MLX_MTL_PRAGMA_UNROLL
614 for (int j = 0; j < TN; j++) {
615 out_vec[out_col + j] = result[j];
616 }
617 }
618 }
619};
620
621///////////////////////////////////////////////////////////////////////////////
622/// Matrix vector multiplication
623///////////////////////////////////////////////////////////////////////////////
624
625template <
626 typename T,
627 typename out_mask_t,
628 typename op_mask_t,
629 const int BM, /* Threadgroup rows (in simdgroups) */
630 const int BN, /* Threadgroup cols (in simdgroups) */
631 const int SM, /* Simdgroup rows (in threads) */
632 const int SN, /* Simdgroup cols (in threads) */
633 const int TM, /* Thread rows (in elements) */
634 const int TN, /* Thread cols (in elements) */
635 const bool kDoNCBatch> /* Batch ndim > 1 */
636[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
637 const device T* mat [[buffer(0)]],
638 const device T* in_vec [[buffer(1)]],
639 device T* out_vec [[buffer(3)]],
640 const constant int& in_vec_size [[buffer(4)]],
641 const constant int& out_vec_size [[buffer(5)]],
642 const constant int& marix_ld [[buffer(6)]],
643 const constant int& batch_ndim [[buffer(9)]],
644 const constant int* batch_shape [[buffer(10)]],
645 const constant size_t* vector_batch_stride [[buffer(11)]],
646 const constant size_t* matrix_batch_stride [[buffer(12)]],
647 const device out_mask_t* out_mask [[buffer(20)]],
648 const device op_mask_t* mat_mask [[buffer(21)]],
649 const device op_mask_t* vec_mask [[buffer(22)]],
650 const constant int* mask_strides [[buffer(23)]],
651 const constant size_t* mask_batch_strides [[buffer(24)]],
652 uint3 tid [[threadgroup_position_in_grid]],
653 uint3 lid [[thread_position_in_threadgroup]],
654 uint simd_gid [[simdgroup_index_in_threadgroup]],
655 uint simd_lid [[thread_index_in_simdgroup]]) {
656 using gemv_kernel =
657 GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
658 threadgroup T tgp_memory
659 [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
660
661 constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
662 constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
663
664 // Update batch offsets
665 if (kDoNCBatch) {
666 in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
667 mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
668
669 if (has_output_mask) {
670 out_mask +=
671 elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
672 mask_batch_strides += batch_ndim;
673 }
674
675 if (has_operand_mask) {
676 const constant size_t* mask_strides_mat = mask_batch_strides;
677 const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
678
679 ulong2 batch_offsets = elem_to_loc_broadcast(
680 tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
681
682 mat_mask += batch_offsets.x;
683 vec_mask += batch_offsets.y;
684 }
685
686 } else {
687 in_vec += tid.z * vector_batch_stride[0];
688 mat += tid.z * matrix_batch_stride[0];
689
690 if (has_output_mask) {
691 out_mask += tid.z * mask_batch_strides[0];
692 mask_batch_strides += batch_ndim;
693 }
694
695 if (has_operand_mask) {
696 mat_mask += tid.z * mask_batch_strides[0];
697 vec_mask += tid.z * mask_batch_strides[batch_ndim];
698 }
699 }
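// Note (added for clarity): mask_batch_strides is laid out as consecutive
// batch_ndim-sized groups, output-mask batch strides first (only when an
// output mask is present), then matrix-mask strides, then vector-mask strides,
// which is why the pointer is advanced by batch_ndim after the output mask is
// handled.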
700
701 out_vec += tid.z * out_vec_size;
702
703 gemv_kernel::run(
704 mat,
705 in_vec,
706 out_vec,
707 in_vec_size,
708 out_vec_size,
709 marix_ld,
710 out_mask,
711 mat_mask,
712 vec_mask,
713 mask_strides,
714 gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
715 tid,
716 lid,
717 simd_gid,
718 simd_lid);
719}
720
721///////////////////////////////////////////////////////////////////////////////
722/// Vector matrix multiplication
723///////////////////////////////////////////////////////////////////////////////
724
725template <
726 typename T,
727 typename out_mask_t,
728 typename op_mask_t,
729 const int BM, /* Threadgroup rows (in simdgroups) */
730 const int BN, /* Threadgroup cols (in simdgroups) */
731 const int SM, /* Simdgroup rows (in threads) */
732 const int SN, /* Simdgroup cols (in threads) */
733 const int TM, /* Thread rows (in elements) */
734 const int TN, /* Thread cols (in elements) */
735 const bool kDoNCBatch> /* Batch ndim > 1 */
736[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
737 const device T* mat [[buffer(0)]],
738 const device T* in_vec [[buffer(1)]],
739 device T* out_vec [[buffer(3)]],
740 const constant int& in_vec_size [[buffer(4)]],
741 const constant int& out_vec_size [[buffer(5)]],
742 const constant int& marix_ld [[buffer(6)]],
743 const constant int& batch_ndim [[buffer(9)]],
744 const constant int* batch_shape [[buffer(10)]],
745 const constant size_t* vector_batch_stride [[buffer(11)]],
746 const constant size_t* matrix_batch_stride [[buffer(12)]],
747 const device out_mask_t* out_mask [[buffer(20)]],
748 const device op_mask_t* mat_mask [[buffer(21)]],
749 const device op_mask_t* vec_mask [[buffer(22)]],
750 const constant int* mask_strides [[buffer(23)]],
751 const constant size_t* mask_batch_strides [[buffer(24)]],
752 uint3 tid [[threadgroup_position_in_grid]],
753 uint3 lid [[thread_position_in_threadgroup]],
754 uint simd_gid [[simdgroup_index_in_threadgroup]],
755 uint simd_lid [[thread_index_in_simdgroup]]) {
756 using gemv_kernel =
757 GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
758 threadgroup T tgp_memory
759 [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
760
761 constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
762 constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
763
764 // Update batch offsets
765 if (kDoNCBatch) {
766 in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
767 mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
768
769 if (has_output_mask) {
770 out_mask +=
771 elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
772 mask_batch_strides += batch_ndim;
773 }
774
775 if (has_operand_mask) {
776 const constant size_t* mask_strides_mat = mask_batch_strides;
777 const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
778
779 ulong2 batch_offsets = elem_to_loc_broadcast(
780 tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
781
782 mat_mask += batch_offsets.x;
783 vec_mask += batch_offsets.y;
784 }
785
786 } else {
787 in_vec += tid.z * vector_batch_stride[0];
788 mat += tid.z * matrix_batch_stride[0];
789
790 if (has_output_mask) {
791 out_mask += tid.z * mask_batch_strides[0];
792 mask_batch_strides += batch_ndim;
793 }
794
795 if (has_operand_mask) {
796 mat_mask += tid.z * mask_batch_strides[0];
797 vec_mask += tid.z * mask_batch_strides[batch_ndim];
798 }
799 }
800
801 out_vec += tid.z * out_vec_size;
802
803 gemv_kernel::run(
804 mat,
805 in_vec,
806 out_vec,
807 in_vec_size,
808 out_vec_size,
809 marix_ld,
810 out_mask,
811 mat_mask,
812 vec_mask,
813 mask_strides,
814 gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
815 tid,
816 lid,
817 simd_gid,
818 simd_lid);
819}
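// Usage sketch (added for illustration, not part of the original header): the
// kernel templates above are presumably instantiated elsewhere with explicit
// host-visible names, in the style MLX uses for its Metal kernels. Under that
// assumption, a single float instantiation with boolean masks and the sample
// tile sizes used in the notes above could look like the following; the
// host_name string and the parameter values are hypothetical:
//
//   template [[host_name("gemv_masked_float_bool_bool_nc0")]] [[kernel]]
//   decltype(gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>)
//       gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>;
//
// Any chosen tile parameters must satisfy GEMVKernel's static_asserts:
// SM * SN == 32, SN in {8, 16, 32}, and blockN >= blockM.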