#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
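// For example, a member declared as
//   MLX_MTL_CONST int blockM = /* ... */;
// expands to
//   static constant constexpr const int blockM = /* ... */;
// and MLX_MTL_PRAGMA_UNROLL placed directly before a loop asks the compiler
// to fully unroll that loop.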
struct _NoMask {
  char x;
  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};
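// Minimal usage sketch (illustrative values, not part of this header):
//   ScaleOp<float> op{0.5f};
//   float y = op.apply(x); // y == 0.5f * x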
  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

  static_assert(
      SN == 8 || SN == 16 || SN == 32,
      "gemv block must have a width of 8, 16, or 32");

  static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
  static METAL_FUNC void load_unsafe(
      const device T* src,
      thread T dst[TN],
      const int src_offset = 0) {
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = src[src_offset + tn];
    }
  }
  static METAL_FUNC void load_safe(
      const device T* src,
      thread T dst[TN],
      const int src_offset = 0,
      const int src_size = TN) {
    if (src_offset + TN <= src_size) {
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src[src_offset + tn];
      }
    } else { // Edge case: clamp reads to src_size and zero-fill the rest
      for (int tn = 0; tn < TN; tn++) {
        dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
      }
    }
  }
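  // Note: load_unsafe assumes all TN reads starting at src_offset are in
  // bounds; load_safe clamps reads to src_size and zero-fills the remainder,
  // so it is the variant used for the leftover (tail) block in run() below.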
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    // Thread-local accumulation buffers
    thread T result[TM] = {0};
    thread T inter[TN];
    thread T v_coeff[TN];
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

    int bm = (simdM + thrM) * TM;
    int bn = (simdN + thrN) * TN;
    int out_row = tid.x * blockM + bm;

    // Exit this simdgroup if its rows are out of bounds
    if (out_row >= out_vec_size)
      return;

    // Shift the tail block inwards so reads stay in bounds
    out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
    // Prepare mask strides and offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides = /* ... */;
    const constant int* vec_mask_strides = /* ... */;

    const int out_mask_offset = /* ... */;

    int mat_mask_offset = /* ... */;
    int vec_mask_offset = 0;
    T out_scale{1};

    // Output-mask handling (guards elided): zero the block and return when
    // the mask is off; otherwise keep the multiplicative scale
    auto mask_out = out_mask[out_mask_offset];
    if (simdN == 0 && thrN == 0) {
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = T(0.);
      }
    }
    // ...
    out_scale = T(mask_out);
    // Advance the matrix pointer to this threadgroup's rows
    mat += out_row * matrix_ld;
    constexpr const uniform<int> loop_stride = make_uniform(blockN);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;
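    // The input is consumed as n_iter full blocks of width blockN with
    // unchecked loads; the remaining `leftover` columns are handled after the
    // main loop with bounds-checked loads.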
    for (int i = 0; i < n_iter; ++i) {
      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        load_unsafe(in_vec, v_coeff, bn);

        // Apply the operand-mask scale
        if (has_mul_operand_mask) {
          for (int tn = 0; tn < TN; tn++) {
            v_coeff[tn] *= block_scale;
          }
        }

        // Accumulate TM rows of TN columns
        int mat_offset = 0;
        for (int tm = 0; tm < TM; tm++) {
          load_unsafe(mat, inter, mat_offset + bn);
          for (int tn = 0; tn < TN; tn++) {
            result[tm] += inter[tn] * v_coeff[tn];
          }
          mat_offset += matrix_ld;
        }
      }

      bn += blockN;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }
    // Tail block: bounds-checked loads for the leftover columns
    if (leftover > 0 &&
        (!has_operand_mask ||
         (bool(mat_mask[mat_mask_offset]) &&
          bool(vec_mask[vec_mask_offset])))) {
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale =
            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
      }

      load_safe(in_vec, v_coeff, bn, in_size);

      if (has_mul_operand_mask) {
        for (int tn = 0; tn < TN; tn++) {
          v_coeff[tn] *= block_scale;
        }
      }

      for (int tm = 0; tm < TM; tm++) {
        load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }
      }
    }
    // Apply out scale
    if (has_mul_output_mask) {
      for (int tm = 0; tm < TM; tm++) {
        result[tm] *= out_scale;
      }
    }
    // Simdgroup accumulations
    for (int tm = 0; tm < TM; tm++) {
      for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
        result[tm] += simd_shuffle_down(result[tm], sn);
      }
    }
    // Threadgroup accumulation (only needed when BN > 1)
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
      if (thrN == 0) {
        for (int tm = 0; tm < TM; tm++) {
          tgp_results[tm] = result[tm];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgN == 0) {
          for (int sgn = 1; sgn < BN; sgn++) {
            for (int tm = 0; tm < TM; tm++) {
              result[tm] += tgp_results[sgn * (blockM + TM) + tm];
            }
          }
        }
      }
    }
    // Write outputs
    if (simdN == 0 && thrN == 0) {
      for (int tm = 0; tm < TM; tm++) {
        out_vec[out_row + tm] = result[tm];
      }
    }
  }
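// The transposed variant (GEMVTKernel) follows. It computes a vector-matrix
// product: each thread accumulates TN output columns while stepping down
// blockM rows of the input vector per iteration.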
  static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
  static METAL_FUNC void run(
      const device T* mat [[buffer(0)]],
      const device T* in_vec [[buffer(1)]],
      device T* out_vec [[buffer(3)]],
      const constant int& in_vec_size [[buffer(4)]],
      const constant int& out_vec_size [[buffer(5)]],
      const constant int& matrix_ld [[buffer(6)]],
      const device out_mask_t* out_mask [[buffer(20)]],
      const device op_mask_t* mat_mask [[buffer(21)]],
      const device op_mask_t* vec_mask [[buffer(22)]],
      const constant int* mask_strides [[buffer(23)]],
      threadgroup T* tgp_memory [[threadgroup(0)]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]],
      uint simd_gid [[simdgroup_index_in_threadgroup]],
      uint simd_lid [[thread_index_in_simdgroup]]) {
    const int thrM = SN != 32 ? simd_lid / SN : 0;
    const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

    const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
    const int sgN = BN != 1 ? (simd_gid % BN) : 0;

    const int simdM = SM * sgM;
    const int simdN = SN * sgN;

    int cm = (simdM + thrM);
    int cn = (simdN + thrN);

    int bm = cm * TM;
    int bn = cn * TN;
    int out_col = tid.x * blockN + bn;
    // Prepare mask strides and offsets
    const constant int* out_mask_strides = mask_strides;
    const constant int* mat_mask_strides = /* ... */;
    const constant int* vec_mask_strides = /* ... */;

    const int out_mask_offset = /* ... */;

    int mat_mask_offset = /* ... */;
    int vec_mask_offset = 0;
    T out_scale{1};

    // Output-mask handling (guards elided): zero the block and return when
    // the mask is off; otherwise keep the multiplicative scale
    auto mask_out = out_mask[out_mask_offset];
    if (cm == 0 && out_col < out_vec_size) {
      if (out_col + TN <= out_vec_size) {
        for (int tn = 0; tn < TN; tn++) {
          out_vec[out_col + tn] = T(0.);
        }
      } else {
        for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
          out_vec[out_col + tn] = T(0.);
        }
      }
    }
    // ...
    out_scale = T(mask_out);
    constexpr const uniform<int> loop_stride = make_uniform(blockM);
    const uniform<int> in_size = make_uniform(in_vec_size);
    const uniform<int> n_iter = in_size / loop_stride;
    const uniform<int> last_iter = loop_stride * n_iter;
    const uniform<int> leftover = in_size - last_iter;
    // Shift the tail block inwards so writes stay in bounds
    if (out_col < out_vec_size) {
      out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
    }
    for (int i = 0; i < n_iter; ++i) {
      threadgroup_barrier(mem_flags::mem_none);

      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        // Load the vector coefficients for this block of rows
        for (int tm = 0; tm < TM; tm++) {
          v_coeff[tm] = in_vec[bm + tm];
        }

        // Apply the operand-mask scale
        if (has_mul_operand_mask) {
          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] *= block_scale;
          }
        }

        // Accumulate TN output columns across TM rows
        for (int tm = 0; tm < TM; tm++) {
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
          }
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }

      bm += blockM;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }
    // Tail block: bounds-checked accesses for the leftover rows
    if (leftover > 0 &&
        (!has_operand_mask ||
         (bool(mat_mask[mat_mask_offset]) &&
          bool(vec_mask[vec_mask_offset])))) {
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale =
            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
      }

      for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
        v_coeff[tm] = in_vec[bm + tm];

        if (has_mul_operand_mask) {
          v_coeff[tm] *= block_scale;
        }

        for (int tn = 0; tn < TN; tn++) {
          inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
        }
        for (int tn = 0; tn < TN; tn++) {
          result[tn] += v_coeff[tm] * inter[tn];
        }
      }
    }
    // Apply out scale
    if (has_mul_output_mask) {
      for (int tn = 0; tn < TN; tn++) {
        result[tn] *= out_scale;
      }
    }
    // Simdgroup accumulations
    for (int tn = 0; tn < TN; tn++) {
      for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
        result[tn] += simd_shuffle_down(result[tn], SN * sm);
      }
    }
    // Threadgroup accumulation (only needed when BM > 1)
    if (needs_tgp_reduction) {
      threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
      if (thrM == 0) {
        for (int tn = 0; tn < TN; tn++) {
          tgp_results[tn] = result[tn];
        }

        threadgroup_barrier(mem_flags::mem_none);

        if (sgM == 0) {
          for (int sgm = 1; sgm < BM; sgm++) {
            for (int tn = 0; tn < TN; tn++) {
              result[tn] += tgp_results[sgm * (blockN + TN) + tn];
            }
          }
        }
      }
    }
    // Write outputs
    if (cm == 0 && out_col < out_vec_size) {
      for (int j = 0; j < TN; j++) {
        out_vec[out_col + j] = result[j];
      }
    }
  }
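// The [[kernel]] entry points below are thin wrappers: they resolve batch
// offsets for the matrix, the vector, and the masks (via elem_to_loc when
// kDoNCBatch is set, or a single stride otherwise) and forward to the
// corresponding kernel struct's run().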
// Matrix vector multiplication.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM,
    const int BN,
    const int SM,
    const int SN,
    const int TM,
    const int TN,
    const bool kDoNCBatch>
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat, in_vec, out_vec, in_vec_size, out_vec_size, matrix_ld,
      out_mask, mat_mask, vec_mask, mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid, lid, simd_gid, simd_lid);
}
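// Illustrative only: how a concrete variant of the kernel above could be
// instantiated. The tile sizes (BM=2, BN=2, SM=4, SN=8, TM=4, TN=4) and the
// host_name string are hypothetical example values chosen to satisfy the
// static_asserts above, not ones mandated by this header.
template [[host_name("gemv_masked_float_bool_bool_2_2_4_8_4_4")]] [[kernel]]
decltype(gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>)
    gemv_masked<float, bool, bool, 2, 2, 4, 8, 4, 4, false>;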
// Vector matrix multiplication.
template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    const int BM,
    const int BN,
    const int SM,
    const int SN,
    const int TM,
    const int TN,
    const bool kDoNCBatch>
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using gemv_kernel =
      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;

  gemv_kernel::run(
      mat, in_vec, out_vec, in_vec_size, out_vec_size, matrix_ld,
      out_mask, mat_mask, vec_mask, mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid, lid, simd_gid, simd_lid);
}