#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
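// _NoMask (aliased as nomask_t) is the sentinel "no mask" type: every
// conversion to bool yields true, so instantiating a kernel with nomask_t
// effectively disables that mask at compile time.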
constexpr METAL_FUNC operator bool() {
  return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
  return true;
}
constexpr METAL_FUNC operator bool() const device {
  return true;
}
constexpr METAL_FUNC operator bool() const constant {
  return true;
}
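// Multiplicative masks are applied through a small scale functor: apply(x)
// casts x to OutT and multiplies it by the stored scale.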
template <typename OutT, typename InT = OutT>
METAL_FUNC OutT apply(InT x) const {
  return static_cast<OutT>(x) * scale;
}
    typename AccT = float>
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

static_assert(
    SN == 8 || SN == 16 || SN == 32,
    "gemv block must have a width of 8, 16, or 32");

static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
template <typename U = T>
static METAL_FUNC void load_unsafe(
    const device T* src,
    thread U dst[TN],
    const int src_offset = 0) {
  for (int tn = 0; tn < TN; tn++) {
    dst[tn] = static_cast<U>(src[src_offset + tn]);
  }
}
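// load_safe is the bounds-checked counterpart of load_unsafe: it only reads
// elements below src_size and zero-fills the remainder of dst.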
template <typename U = T>
static METAL_FUNC void load_safe(
    const device T* src,
    thread U dst[TN],
    const int src_offset = 0,
    const int src_size = TN) {
  if (src_offset + TN <= src_size) {
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = static_cast<U>(src[src_offset + tn]);
    }
  } else {
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = src_offset + tn < src_size
          ? static_cast<U>(src[src_offset + tn])
          : U(0);
    }
  }
}
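// run(): each thread accumulates TM output rows; partial sums are first
// reduced across the simdgroup and then, when more than one simdgroup column
// is used, across simdgroups through threadgroup memory.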
static METAL_FUNC void run(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    threadgroup AccT* tgp_memory [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  thread AccT result[TM] = {0};
  thread T inter[TN];
  thread AccT v_coeff[TN];
  const int thrM = SN != 32 ? simd_lid / SN : 0;
  const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

  const int sgN = BN != 1 ? (simd_gid % BN) : 0;

  const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
  const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

  int bm = (simdM + thrM) * TM;
  int bn = (simdN + thrN) * TN;
  int out_row = tid.x * blockM + bm;

  if (out_row >= out_vec_size)
    return;

  out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
  const constant int* out_mask_strides = mask_strides;
  const constant int* mat_mask_strides =
  const constant int* vec_mask_strides =

  const int out_mask_offset =
  int mat_mask_offset =
  int vec_mask_offset = 0;
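  // Output mask handling: when the block's output mask is off, the first
  // thread column zero-fills out_vec; a multiplicative mask is instead
  // folded into out_scale and applied to result[] further down.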
  auto mask_out = out_mask[out_mask_offset];

  if (simdN == 0 && thrN == 0) {
    for (int tm = 0; tm < TM; tm++) {
      out_vec[out_row + tm] = T(0.);
    }
  }

  out_scale = T(mask_out);

  mat += out_row * matrix_ld;
  constexpr const uniform<int> loop_stride = make_uniform(blockN);
  const uniform<int> in_size = make_uniform(in_vec_size);
  const uniform<int> n_iter = in_size / loop_stride;
  const uniform<int> last_iter = loop_stride * n_iter;
  const uniform<int> leftover = in_size - last_iter;
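  // The reduction dimension is walked in blockN-sized chunks: n_iter full
  // chunks with unchecked loads, then a leftover tail handled below with
  // load_safe / explicit bounds checks.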
  for (int i = 0; i < n_iter; ++i) {
    if (!has_operand_mask ||
        (bool(mat_mask[mat_mask_offset]) &&
         bool(vec_mask[vec_mask_offset]))) {
      block_scale =
          T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);

      for (int tn = 0; tn < TN; tn++) {
        v_coeff[tn] *= block_scale;
      }

      for (int tm = 0; tm < TM; tm++) {
        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }

        mat_offset += matrix_ld;
      }
    }

    mat_mask_offset += mat_mask_step;
    vec_mask_offset += vec_mask_step;
  }
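  // Leftover tail: same accumulation as the main loop, but rows are read
  // with load_safe so that reads past in_size are clamped.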
  if (leftover > 0 &&
      (!has_operand_mask ||
       (bool(mat_mask[mat_mask_offset]) &&
        bool(vec_mask[vec_mask_offset])))) {
    block_scale =
        T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);

    for (int tn = 0; tn < TN; tn++) {
      v_coeff[tn] *= block_scale;
    }

    for (int tm = 0; tm < TM; tm++) {
      load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

      for (int tn = 0; tn < TN; tn++) {
        result[tm] += inter[tn] * v_coeff[tn];
      }
    }
  }
  for (int tm = 0; tm < TM; tm++) {
    result[tm] *= out_scale;
  }
  for (int tm = 0; tm < TM; tm++) {
    for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
      result[tm] += simd_shuffle_down(result[tm], sn);
    }
  }
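  // Cross-simdgroup reduction: each simdgroup column writes its partial rows
  // to threadgroup memory, and the leading column accumulates them after the
  // barrier (only relevant when BN > 1).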
  threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;

  for (int tm = 0; tm < TM; tm++) {
    tgp_results[tm] = result[tm];
  }

  threadgroup_barrier(mem_flags::mem_none);

  for (int sgn = 1; sgn < BN; sgn++) {
    for (int tm = 0; tm < TM; tm++) {
      result[tm] += tgp_results[sgn * (blockM + TM) + tm];
    }
  }
  if (simdN == 0 && thrN == 0) {
    for (int tm = 0; tm < TM; tm++) {
      out_vec[out_row + tm] = static_cast<T>(result[tm]);
    }
  }
    typename AccT = float>

static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static METAL_FUNC void run(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    threadgroup AccT* tgp_memory [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  AccT result[TN] = {0};
  T inter[TN];
  AccT v_coeff[TM];
  const int thrM = SN != 32 ? simd_lid / SN : 0;
  const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

  const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
  const int sgN = BN != 1 ? (simd_gid % BN) : 0;

  const int simdM = SM * sgM;
  const int simdN = SN * sgN;

  int cm = (simdM + thrM);
  int cn = (simdN + thrN);
  int bm = cm * TM;
  int bn = cn * TN;

  int out_col = tid.x * blockN + bn;
  const constant int* out_mask_strides = mask_strides;
  const constant int* mat_mask_strides =
  const constant int* vec_mask_strides =

  const int out_mask_offset =
  int mat_mask_offset =
  int vec_mask_offset = 0;
  auto mask_out = out_mask[out_mask_offset];

  if (cm == 0 && out_col < out_vec_size) {
    if (out_col + TN <= out_vec_size) {
      for (int tn = 0; tn < TN; tn++) {
        out_vec[out_col + tn] = T(0.);
      }
    } else {
      for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
        out_vec[out_col + tn] = T(0.);
      }
    }
  }

  out_scale = T(mask_out);
  constexpr const uniform<int> loop_stride = make_uniform(blockM);
  const uniform<int> in_size = make_uniform(in_vec_size);
  const uniform<int> n_iter = in_size / loop_stride;
  const uniform<int> last_iter = loop_stride * n_iter;
  const uniform<int> leftover = in_size - last_iter;
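  // Transposed traversal: here the reduction dimension (rows of mat) is
  // walked in blockM-sized chunks, and each thread accumulates TN output
  // columns in result[].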
  if (out_col < out_vec_size) {
    out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
  for (int i = 0; i < n_iter; ++i) {
    threadgroup_barrier(mem_flags::mem_none);

    if (!has_operand_mask ||
        (bool(mat_mask[mat_mask_offset]) &&
         bool(vec_mask[vec_mask_offset]))) {
      block_scale =
          T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);

      for (int tm = 0; tm < TM; tm++) {
        v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
      }

      for (int tm = 0; tm < TM; tm++) {
        v_coeff[tm] *= block_scale;
      }

      for (int tm = 0; tm < TM; tm++) {
        for (int tn = 0; tn < TN; tn++) {
          inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
        }

        for (int tn = 0; tn < TN; tn++) {
          result[tn] += v_coeff[tm] * inter[tn];
        }
      }
    }

    mat_mask_offset += mat_mask_step;
    vec_mask_offset += vec_mask_step;
  }
  if (leftover > 0 &&
      (!has_operand_mask ||
       (bool(mat_mask[mat_mask_offset]) &&
        bool(vec_mask[vec_mask_offset])))) {
    block_scale =
        T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);

    for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
      v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
      v_coeff[tm] *= block_scale;

      for (int tn = 0; tn < TN; tn++) {
        inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
      }

      for (int tn = 0; tn < TN; tn++) {
        result[tn] += v_coeff[tm] * inter[tn];
      }
    }
  }
  for (int tn = 0; tn < TN; tn++) {
    result[tn] *= out_scale;
  }
  for (int tn = 0; tn < TN; tn++) {
    for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
      result[tn] += simd_shuffle_down(result[tn], SN * sm);
    }
  }
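  // Threads holding the same output column sit SN lanes apart in the
  // simdgroup, hence the SN * sm shuffle distance; when BM > 1 the partial
  // columns are then combined across simdgroup rows via threadgroup memory.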
  threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;

  for (int tn = 0; tn < TN; tn++) {
    tgp_results[tn] = result[tn];
  }

  threadgroup_barrier(mem_flags::mem_none);

  for (int sgm = 1; sgm < BM; sgm++) {
    for (int tn = 0; tn < TN; tn++) {
      result[tn] += tgp_results[sgm * (blockN + TN) + tn];
    }
  }
  if (cm == 0 && out_col < out_vec_size) {
    for (int j = 0; j < TN; j++) {
      out_vec[out_col + j] = static_cast<T>(result[j]);
    }
  }
    const bool kDoNCBatch>
// Matrix vector multiplication.
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant int64_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
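  // Advance the operand and mask pointers to this batch element (tid.z):
  // kDoNCBatch selects generic N-D batch strides resolved with elem_to_loc,
  // otherwise the leading batch stride is used directly.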
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_mat = mask_batch_strides;
      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

      auto batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }
  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }
  out_vec += tid.z * out_vec_size;
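  // What follows is the dispatch into gemv_kernel::run; when tgp_mem_size is
  // zero no threadgroup reduction is needed, so nullptr is passed in place of
  // the scratch buffer.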
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
    const bool kDoNCBatch>
// Vector matrix multiplication.
[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant int64_t* vector_batch_stride [[buffer(11)]],
    const constant int64_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant int64_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  threadgroup float tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_mat = mask_batch_strides;
      const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;

      auto batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }
  } else {
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }
  out_vec += tid.z * out_vec_size;
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,