gemv_masked.h
1// Copyright © 2023-2024 Apple Inc.
2
4
5using namespace metal;
6
7#define MLX_MTL_CONST static constant constexpr const
8#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
9
10struct _NoMask {
11 char x;
12
13 constexpr METAL_FUNC operator bool() {
14 return true;
15 }
16 constexpr METAL_FUNC operator bool() const threadgroup {
17 return true;
18 }
19 constexpr METAL_FUNC operator bool() const device {
20 return true;
21 }
22 constexpr METAL_FUNC operator bool() const constant {
23 return true;
24 }
25};
26
27typedef struct _NoMask nomask_t;
28
29template <typename OutT, typename InT = OutT>
30struct ScaleOp {
31 OutT scale;
32
33 METAL_FUNC OutT apply(InT x) const {
34 return static_cast<OutT>(x) * scale;
35 }
36};
37
38template <
39 typename T,
40 typename out_mask_t,
41 typename op_mask_t,
42 const int BM, /* Threadgroup rows (in simdgroups) */
43 const int BN, /* Threadgroup cols (in simdgroups) */
44 const int SM, /* Simdgroup rows (in threads) */
45 const int SN, /* Simdgroup cols (in threads) */
46 const int TM, /* Thread rows (in elements) */
47 const int TN, /* Thread cols (in elements) */
48 typename AccT = float>
49struct GEMVKernel {
50 MLX_MTL_CONST int threadsM = BM * SM;
51 MLX_MTL_CONST int threadsN = BN * SN;
52
53 MLX_MTL_CONST int blockM = threadsM * TM;
54 MLX_MTL_CONST int blockN = threadsN * TN;
55
56 static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
57
58 static_assert(
59 SN == 8 || SN == 16 || SN == 32,
60 "gemv block must have a width of 8, 16, or 32");
61
62 static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
63
64 MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
65 MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
66
67 MLX_MTL_CONST bool has_mul_operand_mask =
68 has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
69 MLX_MTL_CONST bool has_mul_output_mask =
70 has_output_mask && !metal::is_same_v<out_mask_t, bool>;
71
72 // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
73 // into blocks of (blockM, blockN) divided among threadgroups
74 // - Every thread works on a block of (TM, TN)
75 // - We assume each threadgroup has (threadsN, threadsM, 1) threads
76 //
77 // 1. A thread loads TN elements each from mat along TM rows
78 // and the corresponding scalar from the vector
79 // 2. The thread then multiplies and adds to accumulate its local result for
80 // the block
81 // 3. At the end, each thread has accumulated results over all blocks across
82 // the rows. These are then summed up across the threadgroup
83 // 4. Each threadgroup writes its accumulated blockM outputs
84 //
85 // Edge case handling:
86 // - The threadgroup with the largest tid has blocks that exceed the matrix
87 // * The blocks that start outside the matrix are never read (thread results
88 // remain zero)
89 // * The last thread that partially overlaps with the matrix is shifted
90 // inwards such that the thread block fits exactly in the matrix
91
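// Illustrative tile sizes (not part of the original source): with BM = 2,
// BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the constants below evaluate to
//   threadsM = BM * SM = 8,   threadsN = BN * SN = 16,
//   blockM   = threadsM * TM = 32,   blockN = threadsN * TN = 64,
// so a threadgroup of 8 x 16 threads produces 32 output rows and consumes the
// input vector in chunks of 64 elements per iteration.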
92 MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
93 MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
94
95 template <typename U = T>
96 static METAL_FUNC void
97 load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
98 MLX_MTL_PRAGMA_UNROLL
99 for (int tn = 0; tn < TN; tn++) {
100 dst[tn] = static_cast<U>(src[src_offset + tn]);
101 }
102 }
103
104 template <typename U = T>
105 static METAL_FUNC void load_safe(
106 const device T* src,
107 thread U dst[TN],
108 const int src_offset = 0,
109 const int src_size = TN) {
110 if (src_offset + TN <= src_size) {
111 MLX_MTL_PRAGMA_UNROLL
112 for (int tn = 0; tn < TN; tn++) {
113 dst[tn] = static_cast<U>(src[src_offset + tn]);
114 }
115 } else { // Edge case
116 MLX_MTL_PRAGMA_UNROLL
117 for (int tn = 0; tn < TN; tn++) {
118 dst[tn] = src_offset + tn < src_size
119 ? static_cast<U>(src[src_offset + tn])
120 : U(0);
121 }
122 }
123 }
124
125 static METAL_FUNC void run(
126 const device T* mat [[buffer(0)]],
127 const device T* in_vec [[buffer(1)]],
128 device T* out_vec [[buffer(3)]],
129 const constant int& in_vec_size [[buffer(4)]],
130 const constant int& out_vec_size [[buffer(5)]],
131 const constant int& matrix_ld [[buffer(6)]],
132 const device out_mask_t* out_mask [[buffer(20)]],
133 const device op_mask_t* mat_mask [[buffer(21)]],
134 const device op_mask_t* vec_mask [[buffer(22)]],
135 const constant int* mask_strides [[buffer(23)]],
136 threadgroup AccT* tgp_memory [[threadgroup(0)]],
137 uint3 tid [[threadgroup_position_in_grid]],
138 uint3 lid [[thread_position_in_threadgroup]],
139 uint simd_gid [[simdgroup_index_in_threadgroup]],
140 uint simd_lid [[thread_index_in_simdgroup]]) {
141 // Appease compiler
142 (void)lid;
143
144 // Thread local accumulation results
145 thread AccT result[TM] = {0};
146 thread T inter[TN];
147 thread AccT v_coeff[TN];
148
149 const int thrM = SN != 32 ? simd_lid / SN : 0;
150 const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
151
152 const int sgN = BN != 1 ? (simd_gid % BN) : 0;
153
154 const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
155 const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
156
157 int bm = (simdM + thrM) * TM;
158 int bn = (simdN + thrN) * TN;
159
160 // Block position
161 int out_row = tid.x * blockM + bm;
162
163 // Exit simdgroup if rows out of bound
164 if (out_row >= out_vec_size)
165 return;
166
167 // Adjust tail simdgroup to ensure in bound reads
168 out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
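// Illustrative numbers (not from the original source): with TM = 4 and
// out_vec_size = 70, a tail thread whose out_row would be 68 is shifted back
// to 66 so that its TM rows (66..69) all stay inside the matrix; the
// overlapping rows are simply computed by two threads.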
169
170 // Prepare mask offsets
171 const constant int* out_mask_strides = mask_strides;
172 const constant int* mat_mask_strides =
173 mask_strides + (has_output_mask ? 2 : 0);
174 const constant int* vec_mask_strides =
175 mat_mask_strides + (has_operand_mask ? 2 : 0);
176
177 const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
178
179 const int out_mask_offset =
180 !has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
181
182 int mat_mask_offset =
183 !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
184 int vec_mask_offset = 0;
185 const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
186 const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
187
188 T out_scale{1};
189
190 // Check output mask
191 if (has_output_mask) {
192 auto mask_out = out_mask[out_mask_offset];
193
194 // Write zeros and return if mask is 0
195 if (!mask_out) {
196 if (simdN == 0 && thrN == 0) {
197 MLX_MTL_PRAGMA_UNROLL
198 for (int tm = 0; tm < TM; tm++) {
199 out_vec[out_row + tm] = T(0.);
200 }
201 }
202
203 return;
204 }
205
206 // Store scalar if multiplicative mask
207 if (has_mul_output_mask) {
208 out_scale = T(mask_out);
209 }
210 }
211
212 // Advance matrix
213 mat += out_row * matrix_ld;
214
215 // Prepare for loop
216 constexpr const uniform<int> loop_stride = make_uniform(blockN);
217 const uniform<int> in_size = make_uniform(in_vec_size);
218 const uniform<int> n_iter = in_size / loop_stride;
219 const uniform<int> last_iter = loop_stride * n_iter;
220 const uniform<int> leftover = in_size - last_iter;
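// Illustrative numbers (not from the original source): with blockN = 64 and
// in_vec_size = 200, n_iter = 200 / 64 = 3 full blocks are processed by the
// main loop, last_iter = 192, and the leftover = 8 columns are handled by the
// safe-load path after the loop.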
221
222 // Loop over in_vec in blocks of blockN
223 for (int i = 0; i < n_iter; ++i) {
224 if (!has_operand_mask ||
225 (bool(mat_mask[mat_mask_offset]) &&
226 bool(vec_mask[vec_mask_offset]))) {
227 T block_scale{1};
228 if (has_mul_operand_mask) {
229 block_scale =
230 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
231 }
232
233 load_unsafe<AccT>(in_vec, v_coeff, bn);
234
235 // Apply scale
236 if (has_mul_operand_mask) {
237 MLX_MTL_PRAGMA_UNROLL
238 for (int tn = 0; tn < TN; tn++) {
239 v_coeff[tn] *= block_scale;
240 }
241 }
242
243 // Per thread work loop
244 int mat_offset = 0;
245 MLX_MTL_PRAGMA_UNROLL
246 for (int tm = 0; tm < TM; tm++) {
247 // Load for the row
248 load_unsafe(mat, inter, mat_offset + bn);
249
250 // Accumulate results
251 MLX_MTL_PRAGMA_UNROLL
252 for (int tn = 0; tn < TN; tn++) {
253 result[tm] += inter[tn] * v_coeff[tn];
254 }
255
256 mat_offset += matrix_ld;
257 }
258 }
259
260 bn += blockN;
261 mat_mask_offset += mat_mask_step;
262 vec_mask_offset += vec_mask_step;
263 }
264
265 if (leftover > 0 &&
266 (!has_operand_mask ||
267 (bool(mat_mask[mat_mask_offset]) &&
268 bool(vec_mask[vec_mask_offset])))) {
269 T block_scale{1};
270 if (has_mul_operand_mask) {
271 block_scale =
272 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
273 }
274
275 load_safe<AccT>(in_vec, v_coeff, bn, in_size);
276
277 // Apply scale
278 if (has_mul_operand_mask) {
279 MLX_MTL_PRAGMA_UNROLL
280 for (int tn = 0; tn < TN; tn++) {
281 v_coeff[tn] *= block_scale;
282 }
283 }
284
285 // Per thread work loop
286 MLX_MTL_PRAGMA_UNROLL
287 for (int tm = 0; tm < TM; tm++) {
288 // Load for the row
289 load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
290
291 // Accumulate results
292 MLX_MTL_PRAGMA_UNROLL
293 for (int tn = 0; tn < TN; tn++) {
294 result[tm] += inter[tn] * v_coeff[tn];
295 }
296 }
297 }
298
299 // Apply out scale
300 if (has_mul_output_mask) {
301 MLX_MTL_PRAGMA_UNROLL
302 for (int tm = 0; tm < TM; tm++) {
303 result[tm] *= out_scale;
304 }
305 }
306
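// The loop nest below performs a butterfly reduction: each thread adds the
// partial result of the thread SN/2, SN/4, ..., 1 lanes above it, so the
// complete row total ends up in the thread with thrN == 0.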
307 // Simdgroup accumulations
308 MLX_MTL_PRAGMA_UNROLL
309 for (int tm = 0; tm < TM; tm++) {
310 MLX_MTL_PRAGMA_UNROLL
311 for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
312 result[tm] += simd_shuffle_down(result[tm], sn);
313 }
314 }
315
316 // Threadgroup accumulation results
317 if (needs_tgp_reduction) {
318 threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
319 if (thrN == 0) {
320 MLX_MTL_PRAGMA_UNROLL
321 for (int tm = 0; tm < TM; tm++) {
322 tgp_results[tm] = result[tm];
323 }
324
325 threadgroup_barrier(mem_flags::mem_none);
326
327 if (sgN == 0) {
328 MLX_MTL_PRAGMA_UNROLL
329 for (int sgn = 1; sgn < BN; sgn++) {
330 MLX_MTL_PRAGMA_UNROLL
331 for (int tm = 0; tm < TM; tm++) {
332 result[tm] += tgp_results[sgn * (blockM + TM) + tm];
333 }
334 }
335 }
336 }
337 }
338
339 // Write outputs
340 if (simdN == 0 && thrN == 0) {
341 MLX_MTL_PRAGMA_UNROLL
342 for (int tm = 0; tm < TM; tm++) {
343 out_vec[out_row + tm] = static_cast<T>(result[tm]);
344 }
345 }
346 }
347};
348
349///////////////////////////////////////////////////////////////////////////////
350/// Vector matrix multiplication
351///////////////////////////////////////////////////////////////////////////////
352
353template <
354 typename T,
355 typename out_mask_t,
356 typename op_mask_t,
357 const int BM, /* Threadgroup rows (in simdgroups) */
358 const int BN, /* Threadgroup cols (in simdgroups) */
359 const int SM, /* Simdgroup rows (in threads) */
360 const int SN, /* Simdgroup cols (in threads) */
361 const int TM, /* Thread rows (in elements) */
362 const int TN, /* Thread cols (in elements) */
363 typename AccT = float>
364struct GEMVTKernel {
365 MLX_MTL_CONST int threadsM = BM * SM;
366 MLX_MTL_CONST int threadsN = BN * SN;
367
368 MLX_MTL_CONST int blockM = threadsM * TM;
369 MLX_MTL_CONST int blockN = threadsN * TN;
370
371 static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
372
373 MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
374 MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
375
376 MLX_MTL_CONST bool has_mul_operand_mask =
377 has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
378 MLX_MTL_CONST bool has_mul_output_mask =
379 has_output_mask && !metal::is_same_v<out_mask_t, bool>;
380
381 // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
382 // into blocks of (blockM, blockN) divided among threadgroups
383 // - Every thread works on a block of (TM, TN)
384 // - We assume each threadgroup has (threadsN, threadsM, 1) threads
385 //
386 // 1. A thread loads TN elements each from mat along TM contiguous rows
387 // and the corresponding scalar from the vector
388 // 2. The thread then accumulates its local result for the block
389 // 3. At the end, each thread has accumulated results over all blocks across
390 // the rows. These are then summed up across the threadgroup
391 // 4. Each threadgroup writes its accumulated BN * TN outputs
392 //
393 // Edge case handling:
394 // - The threadgroup with the largest tid has blocks that exceed the matrix
395 // * The blocks that start outside the matrix are never read (thread results
396 // remain zero)
397 // * The last thread that partially overlaps with the matrix is shifted
398 // inwards such that the thread block fits exactly in the matrix
399
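// Illustrative tile sizes (not part of the original source): with BM = 2,
// BN = 2, SM = 4, SN = 8, TM = 4, TN = 4 the constants below evaluate to
//   blockM = BM * SM * TM = 32 and blockN = BN * SN * TN = 64,
// so each threadgroup produces 64 output columns and consumes the input
// vector in chunks of 32 elements per iteration.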
400 MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
401 MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
402
403 static METAL_FUNC void run(
404 const device T* mat [[buffer(0)]],
405 const device T* in_vec [[buffer(1)]],
406 device T* out_vec [[buffer(3)]],
407 const constant int& in_vec_size [[buffer(4)]],
408 const constant int& out_vec_size [[buffer(5)]],
409 const constant int& marix_ld [[buffer(6)]],
410 const device out_mask_t* out_mask [[buffer(20)]],
411 const device op_mask_t* mat_mask [[buffer(21)]],
412 const device op_mask_t* vec_mask [[buffer(22)]],
413 const constant int* mask_strides [[buffer(23)]],
414 threadgroup AccT* tgp_memory [[threadgroup(0)]],
415 uint3 tid [[threadgroup_position_in_grid]],
416 uint3 lid [[thread_position_in_threadgroup]],
417 uint simd_gid [[simdgroup_index_in_threadgroup]],
418 uint simd_lid [[thread_index_in_simdgroup]]) {
419 // Appease compiler
420 (void)lid;
421
422 // Thread local accumulation results
423 AccT result[TN] = {0};
424 T inter[TN];
425 AccT v_coeff[TM];
426
427 const int thrM = SN != 32 ? simd_lid / SN : 0;
428 const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
429
430 const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
431 const int sgN = BN != 1 ? (simd_gid % BN) : 0;
432
433 const int simdM = SM * sgM;
434 const int simdN = SN * sgN;
435
436 int cm = (simdM + thrM);
437 int cn = (simdN + thrN);
438
439 int bm = cm * TM;
440 int bn = cn * TN;
441
442 int out_col = tid.x * blockN + bn;
443
444 // Prepare mask offsets
445 const constant int* out_mask_strides = mask_strides;
446 const constant int* mat_mask_strides =
447 out_mask_strides + (has_output_mask ? 2 : 0);
448 const constant int* vec_mask_strides =
449 mat_mask_strides + (has_operand_mask ? 2 : 0);
450
451 const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
452
453 const int out_mask_offset =
454 !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
455
456 int mat_mask_offset =
457 !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
458 int vec_mask_offset = 0;
459 const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
460 const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
461
462 T out_scale{1};
463
464 // Check output mask
465 if (has_output_mask) {
466 auto mask_out = out_mask[out_mask_offset];
467
468 // Write zeros and return if mask is 0
469 if (!mask_out) {
470 if (cm == 0 && out_col < out_vec_size) {
471 if (out_col + TN <= out_vec_size) {
472 MLX_MTL_PRAGMA_UNROLL
473 for (int tn = 0; tn < TN; tn++) {
474 out_vec[out_col + tn] = T(0.);
475 }
476 } else {
477 for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
478 out_vec[out_col + tn] = T(0.);
479 }
480 }
481 }
482
483 return;
484 }
485
486 // Store scalar if multiplicative mask
487 if (has_mul_output_mask) {
488 out_scale = T(mask_out);
489 }
490 }
491
492 // Prepare for loop
493 constexpr const uniform<int> loop_stride = make_uniform(blockM);
494 const uniform<int> in_size = make_uniform(in_vec_size);
495 const uniform<int> n_iter = in_size / loop_stride;
496 const uniform<int> last_iter = loop_stride * n_iter;
497 const uniform<int> leftover = in_size - last_iter;
498
499 // Edge case handling
500 if (out_col < out_vec_size) {
501 out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
502
503 // Per thread accumulation main loop
504 for (int i = 0; i < n_iter; ++i) {
505 // Adding a threadgroup_barrier improves performance slightly
506 // This is possibly because it helps exploit the cache better
507 threadgroup_barrier(mem_flags::mem_none);
508
509 if (!has_operand_mask ||
510 (bool(mat_mask[mat_mask_offset]) &&
511 bool(vec_mask[vec_mask_offset]))) {
512 T block_scale{1};
513 if (has_mul_operand_mask) {
514 block_scale =
515 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
516 }
517
518 MLX_MTL_PRAGMA_UNROLL
519 for (int tm = 0; tm < TM; tm++) {
520 v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
521 }
522
523 // Apply scale
524 if (has_mul_operand_mask) {
525 MLX_MTL_PRAGMA_UNROLL
526 for (int tm = 0; tm < TM; tm++) {
527 v_coeff[tm] *= block_scale;
528 }
529 }
530
531 MLX_MTL_PRAGMA_UNROLL
532 for (int tm = 0; tm < TM; tm++) {
533 for (int tn = 0; tn < TN; tn++) {
534 inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
535 }
536 for (int tn = 0; tn < TN; tn++) {
537 result[tn] += v_coeff[tm] * inter[tn];
538 }
539 }
540 }
541
542 bm += blockM;
543 mat_mask_offset += mat_mask_step;
544 vec_mask_offset += vec_mask_step;
545 }
546
547 if (leftover > 0 &&
548 (!has_operand_mask ||
549 (bool(mat_mask[mat_mask_offset]) &&
550 bool(vec_mask[vec_mask_offset])))) {
551 T block_scale{1};
552 if (has_mul_operand_mask) {
553 block_scale =
554 T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
555 }
556
557 for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
558 v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
559
560 if (has_mul_operand_mask) {
561 v_coeff[tm] *= block_scale;
562 }
563
564 MLX_MTL_PRAGMA_UNROLL
565 for (int tn = 0; tn < TN; tn++) {
566 inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
567 }
568
569 MLX_MTL_PRAGMA_UNROLL
570 for (int tn = 0; tn < TN; tn++) {
571 result[tn] += v_coeff[tm] * inter[tn];
572 }
573 }
574 }
575 }
576
577 // Apply out scale
578 if (has_mul_output_mask) {
579 MLX_MTL_PRAGMA_UNROLL
580 for (int tn = 0; tn < TN; tn++) {
581 result[tn] *= out_scale;
582 }
583 }
584
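// The loop nest below performs a butterfly reduction across the SM threads
// that share a column (they sit SN lanes apart in the simdgroup), leaving the
// column total in the thread with thrM == 0.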
585 // Simdgroup accumulations
586 MLX_MTL_PRAGMA_UNROLL
587 for (int tn = 0; tn < TN; tn++) {
588 MLX_MTL_PRAGMA_UNROLL
589 for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
590 result[tn] += simd_shuffle_down(result[tn], SN * sm);
591 }
592 }
593
594 // Threadgroup accumulation results
595 if (needs_tgp_reduction) {
596 threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
597 if (thrM == 0) {
598 MLX_MTL_PRAGMA_UNROLL
599 for (int tn = 0; tn < TN; tn++) {
600 tgp_results[tn] = result[tn];
601 }
602
603 threadgroup_barrier(mem_flags::mem_none);
604
605 if (sgM == 0) {
606 MLX_MTL_PRAGMA_UNROLL
607 for (int sgm = 1; sgm < BM; sgm++) {
608 MLX_MTL_PRAGMA_UNROLL
609 for (int tn = 0; tn < TN; tn++) {
610 result[tn] += tgp_results[sgm * (blockN + TN) + tn];
611 }
612 }
613 }
614 }
615 }
616
617 // Threadgroup accumulation and writing out results
618 if (cm == 0 && out_col < out_vec_size) {
619 MLX_MTL_PRAGMA_UNROLL
620 for (int j = 0; j < TN; j++) {
621 out_vec[out_col + j] = static_cast<T>(result[j]);
622 }
623 }
624 }
625};
626
627///////////////////////////////////////////////////////////////////////////////
628/// Matrix vector multiplication
629///////////////////////////////////////////////////////////////////////////////
630
631template <
632 typename T,
633 typename out_mask_t,
634 typename op_mask_t,
635 const int BM, /* Threadgroup rows (in simdgroups) */
636 const int BN, /* Threadgroup cols (in simdgroups) */
637 const int SM, /* Simdgroup rows (in threads) */
638 const int SN, /* Simdgroup cols (in threads) */
639 const int TM, /* Thread rows (in elements) */
640 const int TN, /* Thread cols (in elements) */
641 const bool kDoNCBatch> /* Batch ndim > 1 */
642[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
643 const device T* mat [[buffer(0)]],
644 const device T* in_vec [[buffer(1)]],
645 device T* out_vec [[buffer(3)]],
646 const constant int& in_vec_size [[buffer(4)]],
647 const constant int& out_vec_size [[buffer(5)]],
648 const constant int& marix_ld [[buffer(6)]],
649 const constant int& batch_ndim [[buffer(9)]],
650 const constant int* batch_shape [[buffer(10)]],
651 const constant int64_t* vector_batch_stride [[buffer(11)]],
652 const constant int64_t* matrix_batch_stride [[buffer(12)]],
653 const device out_mask_t* out_mask [[buffer(20)]],
654 const device op_mask_t* mat_mask [[buffer(21)]],
655 const device op_mask_t* vec_mask [[buffer(22)]],
656 const constant int* mask_strides [[buffer(23)]],
657 const constant int64_t* mask_batch_strides [[buffer(24)]],
658 uint3 tid [[threadgroup_position_in_grid]],
659 uint3 lid [[thread_position_in_threadgroup]],
660 uint simd_gid [[simdgroup_index_in_threadgroup]],
661 uint simd_lid [[thread_index_in_simdgroup]]) {
662 using gemv_kernel =
663 GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN, float>;
664 threadgroup float tgp_memory
665 [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
666
667 constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
668 constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
669
670 // Update batch offsets
671 if (kDoNCBatch) {
672 in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
673 mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
674
675 if (has_output_mask) {
676 out_mask +=
677 elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
678 mask_batch_strides += batch_ndim;
679 }
680
681 if (has_operand_mask) {
682 const constant auto* mask_strides_mat = mask_batch_strides;
683 const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
684
685 ulong2 batch_offsets = elem_to_loc_broadcast(
686 tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
687
688 mat_mask += batch_offsets.x;
689 vec_mask += batch_offsets.y;
690 }
691
692 } else {
693 in_vec += tid.z * vector_batch_stride[0];
694 mat += tid.z * matrix_batch_stride[0];
695
696 if (has_output_mask) {
697 out_mask += tid.z * mask_batch_strides[0];
698 mask_batch_strides += batch_ndim;
699 }
700
701 if (has_operand_mask) {
702 mat_mask += tid.z * mask_batch_strides[0];
703 vec_mask += tid.z * mask_batch_strides[batch_ndim];
704 }
705 }
706
707 out_vec += tid.z * out_vec_size;
708
709 gemv_kernel::run(
710 mat,
711 in_vec,
712 out_vec,
713 in_vec_size,
714 out_vec_size,
715 marix_ld,
716 out_mask,
717 mat_mask,
718 vec_mask,
719 mask_strides,
720 gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
721 tid,
722 lid,
723 simd_gid,
724 simd_lid);
725}
726
727///////////////////////////////////////////////////////////////////////////////
728/// Vector matrix multiplication
729///////////////////////////////////////////////////////////////////////////////
730
731template <
732 typename T,
733 typename out_mask_t,
734 typename op_mask_t,
735 const int BM, /* Threadgroup rows (in simdgroups) */
736 const int BN, /* Threadgroup cols (in simdgroups) */
737 const int SM, /* Simdgroup rows (in threads) */
738 const int SN, /* Simdgroup cols (in threads) */
739 const int TM, /* Thread rows (in elements) */
740 const int TN, /* Thread cols (in elements) */
741 const bool kDoNCBatch> /* Batch ndim > 1 */
742[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
743 const device T* mat [[buffer(0)]],
744 const device T* in_vec [[buffer(1)]],
745 device T* out_vec [[buffer(3)]],
746 const constant int& in_vec_size [[buffer(4)]],
747 const constant int& out_vec_size [[buffer(5)]],
748 const constant int& marix_ld [[buffer(6)]],
749 const constant int& batch_ndim [[buffer(9)]],
750 const constant int* batch_shape [[buffer(10)]],
751 const constant int64_t* vector_batch_stride [[buffer(11)]],
752 const constant int64_t* matrix_batch_stride [[buffer(12)]],
753 const device out_mask_t* out_mask [[buffer(20)]],
754 const device op_mask_t* mat_mask [[buffer(21)]],
755 const device op_mask_t* vec_mask [[buffer(22)]],
756 const constant int* mask_strides [[buffer(23)]],
757 const constant int64_t* mask_batch_strides [[buffer(24)]],
758 uint3 tid [[threadgroup_position_in_grid]],
759 uint3 lid [[thread_position_in_threadgroup]],
760 uint simd_gid [[simdgroup_index_in_threadgroup]],
761 uint simd_lid [[thread_index_in_simdgroup]]) {
762 using gemv_kernel =
763 GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN, float>;
764 threadgroup float tgp_memory
765 [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
766
767 constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
768 constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
769
770 // Update batch offsets
771 if (kDoNCBatch) {
772 in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
773 mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
774
775 if (has_output_mask) {
776 out_mask +=
777 elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
778 mask_batch_strides += batch_ndim;
779 }
780
781 if (has_operand_mask) {
782 const constant auto* mask_strides_mat = mask_batch_strides;
783 const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
784
785 ulong2 batch_offsets = elem_to_loc_broadcast(
786 tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
787
788 mat_mask += batch_offsets.x;
789 vec_mask += batch_offsets.y;
790 }
791
792 } else {
793 in_vec += tid.z * vector_batch_stride[0];
794 mat += tid.z * matrix_batch_stride[0];
795
796 if (has_output_mask) {
797 out_mask += tid.z * mask_batch_strides[0];
798 mask_batch_strides += batch_ndim;
799 }
800
801 if (has_operand_mask) {
802 mat_mask += tid.z * mask_batch_strides[0];
803 vec_mask += tid.z * mask_batch_strides[batch_ndim];
804 }
805 }
806
807 out_vec += tid.z * out_vec_size;
808
809 gemv_kernel::run(
810 mat,
811 in_vec,
812 out_vec,
813 in_vec_size,
814 out_vec_size,
815 marix_ld,
816 out_mask,
817 mat_mask,
818 vec_mask,
819 mask_strides,
820 gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
821 tid,
822 lid,
823 simd_gid,
824 simd_lid);
825}