#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

typedef struct _NoMask nomask_t;
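The sentinel exists so the kernels can detect "no mask" at compile time and drop mask reads entirely. A host-side C++ sketch of that pattern, with std::is_same_v standing in for metal::is_same_v; the types here are illustrative stand-ins, not the Metal ones:

#include <cstdio>
#include <type_traits>

// Host-side stand-in for the nomask_t sentinel (no METAL_FUNC or
// address-space qualifiers on the CPU).
struct nomask_t {
  char x;
  constexpr operator bool() const { return true; }
};

template <typename op_mask_t>
float sum_masked(const float* v, const op_mask_t* mask, int n) {
  // When op_mask_t is the sentinel, this is constant-true and the
  // compiler removes both the branch and the mask load.
  constexpr bool has_operand_mask = !std::is_same_v<op_mask_t, nomask_t>;
  float acc = 0;
  for (int i = 0; i < n; ++i) {
    if (!has_operand_mask || bool(mask[i])) {
      acc += v[i];
    }
  }
  return acc;
}

int main() {
  float v[4] = {1, 2, 3, 4};
  bool m[4] = {true, false, true, false};
  nomask_t no[4] = {};
  std::printf("%g %g\n", sum_masked(v, m, 4), sum_masked(v, no, 4)); // 4 10
}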
 
 
 
template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};
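ScaleOp converts the input type and applies a multiplicative factor; the kernels use it to fold mask values into results. A host-side usage sketch (same shape as the struct above, minus the METAL_FUNC qualifier):

#include <cstdio>

template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;
  OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};

int main() {
  ScaleOp<float, int> op{0.5f};
  std::printf("%g\n", op.apply(7)); // 3.5: int converted to float, then scaled
}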
 
 
 
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");

static_assert(
    SN == 8 || SN == 16 || SN == 32,
    "gemv block must have a width of 8, 16, or 32");

static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
 
static METAL_FUNC void load_unsafe(
    const device T* src,
    thread T dst[TN],
    const int src_offset = 0) {
  for (int tn = 0; tn < TN; tn++) {
    dst[tn] = src[src_offset + tn];
  }
}
 
 
static METAL_FUNC void load_safe(
    const device T* src,
    thread T dst[TN],
    const int src_offset = 0,
    const int src_size = TN) {
  if (src_offset + TN <= src_size) {
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = src[src_offset + tn];
    }
  } else {
    for (int tn = 0; tn < TN; tn++) {
      dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
    }
  }
}
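A host-side sketch contrasting the two loaders: load_unsafe assumes the whole TN-wide read is in bounds, while load_safe zero-pads past src_size so the tail block contributes nothing:

#include <cstdio>

constexpr int TN = 4;

void load_unsafe(const float* src, float dst[TN], int src_offset = 0) {
  for (int tn = 0; tn < TN; tn++)
    dst[tn] = src[src_offset + tn];
}

void load_safe(const float* src, float dst[TN], int src_offset = 0,
               int src_size = TN) {
  if (src_offset + TN <= src_size) {
    for (int tn = 0; tn < TN; tn++)
      dst[tn] = src[src_offset + tn];
  } else {
    // Guarded path: zero-pad reads past the end of the source
    for (int tn = 0; tn < TN; tn++)
      dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
  }
}

int main() {
  float src[6] = {1, 2, 3, 4, 5, 6}, dst[TN];
  load_safe(src, dst, 4, 6); // reads 5, 6, then pads with zeros
  std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 5 6 0 0
}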
 
 
static METAL_FUNC void run(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& matrix_ld [[buffer(6)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    threadgroup T* tgp_memory [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  // Per-thread accumulation and staging registers
  thread T result[TM] = {0};
  thread T inter[TN];
  thread T v_coeff[TN];
 
  // Thread location inside the simdgroup (row, column of its tile)
  const int thrM = SN != 32 ? simd_lid / SN : 0;
  const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

  const int sgN = BN != 1 ? (simd_gid % BN) : 0;

  const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
  const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;

  int bm = (simdM + thrM) * TM;
  int bn = (simdN + thrN) * TN;
 
  // Block position and edge handling for the rows this thread writes
  int out_row = tid.x * blockM + bm;

  if (out_row >= out_vec_size)
    return;

  out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
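A sketch of the index math in plain C++: simd_lid (0..31) splits into a (thrM, thrN) coordinate inside an SM x SN simdgroup, and simd_gid selects which simdgroup tile of the threadgroup the lane belongs to. The tile constants are example values of mine:

#include <cstdio>

constexpr int SM = 4, SN = 8, BN = 2, TM = 4, TN = 4;

int main() {
  for (unsigned simd_gid = 0; simd_gid < 4; ++simd_gid) {
    unsigned simd_lid = 17; // an arbitrary lane
    int thrM = SN != 32 ? simd_lid / SN : 0;
    int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
    int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
    int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
    // bm/bn: first row/column of this lane's TM x TN result tile
    std::printf("sg %u -> bm %d, bn %d\n",
                simd_gid, (simdM + thrM) * TM, (simdN + thrN) * TN);
  }
}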
 
  const constant int* out_mask_strides = mask_strides;
  const constant int* mat_mask_strides = /* … */;
  const constant int* vec_mask_strides = /* … */;

  const int out_mask_offset = /* … */;

  int mat_mask_offset = /* … */;
  int vec_mask_offset = 0;
  const int mat_mask_step = /* … */;
  const int vec_mask_step = /* … */;
 
  // Out mask: write zeros and return if the block is masked out; a
  // multiplicative mask becomes the output scale instead
  T out_scale{1};
  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (!mask_out) {
      if (simdN == 0 && thrN == 0) {
        for (int tm = 0; tm < TM; tm++) {
          out_vec[out_row + tm] = T(0.);
        }
      }
      return;
    }

    if (has_mul_output_mask) {
      out_scale = T(mask_out);
    }
  }
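The two output-mask flavors in miniature, host-side: a boolean mask skips the block entirely (zeros written once), while a multiplicative mask scales the finished block. MaskT = bool vs. float selects the behavior; the helper name is mine:

#include <cstdio>

template <typename MaskT>
void apply_out_mask(MaskT mask_out, float* out, int rows, float& out_scale) {
  if (!bool(mask_out)) {
    for (int tm = 0; tm < rows; tm++) out[tm] = 0.f;
    return;
  }
  out_scale = float(mask_out); // for MaskT = bool this stays 1.0, a no-op
}

int main() {
  float out[4] = {9, 9, 9, 9}, scale = 1.f;
  apply_out_mask(0.25f, out, 4, scale); // multiplicative mask
  std::printf("scale = %g\n", scale);   // 0.25
}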
 
  // Advance matrix to the row block this thread handles
  mat += out_row * matrix_ld;

  // Loop over in_vec in blockN-sized chunks; handle the leftover separately
  constexpr const uniform<int> loop_stride = make_uniform(blockN);
  const uniform<int> in_size = make_uniform(in_vec_size);
  const uniform<int> n_iter = in_size / loop_stride;
  const uniform<int> last_iter = loop_stride * n_iter;
  const uniform<int> leftover = in_size - last_iter;
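The chunking arithmetic in plain C++: n_iter full blockN-wide passes over in_vec, then the leftover tail handled by the guarded load path:

#include <cstdio>

int main() {
  const int blockN = 64;
  const int in_vec_size = 200;
  const int n_iter = in_vec_size / blockN;      // 3 full blocks
  const int last_iter = blockN * n_iter;        // 192 elements covered
  const int leftover = in_vec_size - last_iter; // 8 guarded trailing elements
  std::printf("%d full blocks, %d leftover\n", n_iter, leftover);
}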
 
  for (int i = 0; i < n_iter; ++i) {
    if (!has_operand_mask ||
        (bool(mat_mask[mat_mask_offset]) &&
         bool(vec_mask[vec_mask_offset]))) {
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale =
            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
      }

      load_unsafe(in_vec, v_coeff, bn);

      if (has_mul_operand_mask) {
        for (int tn = 0; tn < TN; tn++) {
          v_coeff[tn] *= block_scale;
        }
      }

      int mat_offset = 0;
      for (int tm = 0; tm < TM; tm++) {
        load_unsafe(mat, inter, mat_offset + bn);

        for (int tn = 0; tn < TN; tn++) {
          result[tm] += inter[tn] * v_coeff[tn];
        }

        mat_offset += matrix_ld;
      }
    }

    bn += blockN;
    mat_mask_offset += mat_mask_step;
    vec_mask_offset += vec_mask_step;
  }
 
  if (leftover > 0 &&
      (!has_operand_mask ||
       (bool(mat_mask[mat_mask_offset]) &&
        bool(vec_mask[vec_mask_offset])))) {
    T block_scale{1};
    if (has_mul_operand_mask) {
      block_scale =
          T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
    }

    load_safe(in_vec, v_coeff, bn, in_size);

    if (has_mul_operand_mask) {
      for (int tn = 0; tn < TN; tn++) {
        v_coeff[tn] *= block_scale;
      }
    }

    for (int tm = 0; tm < TM; tm++) {
      load_safe(&mat[tm * matrix_ld], inter, bn, in_size);

      for (int tn = 0; tn < TN; tn++) {
        result[tm] += inter[tn] * v_coeff[tn];
      }
    }
  }

  // Apply out scale
  if (has_mul_output_mask) {
    for (int tm = 0; tm < TM; tm++) {
      result[tm] *= out_scale;
    }
  }
 
  // Simdgroup accumulations: tree-reduce each row sum across the SN lanes
  for (int tm = 0; tm < TM; tm++) {
    for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
      result[tm] += simd_shuffle_down(result[tm], sn);
    }
  }

  // Threadgroup accumulation results
  if (needs_tgp_reduction) {
    threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
    if (thrN == 0) {
      for (int tm = 0; tm < TM; tm++) {
        tgp_results[tm] = result[tm];
      }

      threadgroup_barrier(mem_flags::mem_none);

      if (sgN == 0) {
        for (int sgn = 1; sgn < BN; sgn++) {
          for (int tm = 0; tm < TM; tm++) {
            result[tm] += tgp_results[sgn * (blockM + TM) + tm];
          }
        }
      }
    }
  }
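What the shuffle-down loop computes, modeled sequentially on one SN-lane row in plain C++. Each step folds the upper lanes onto the lower ones, so after log2(SN) steps lane 0 holds the full row sum; lanes past the fold go stale, as on hardware, but only lane 0's value is consumed:

#include <cstdio>

int main() {
  const int SN = 8;
  float lanes[SN] = {1, 2, 3, 4, 5, 6, 7, 8};
  for (int sn = SN / 2; sn >= 1; sn >>= 1) {
    for (int lane = 0; lane + sn < SN; ++lane) {
      lanes[lane] += lanes[lane + sn]; // stands in for simd_shuffle_down(result, sn)
    }
  }
  std::printf("lane 0 sum = %g\n", lanes[0]); // 36
}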
 
  // Write outputs
  if (simdN == 0 && thrN == 0) {
    for (int tm = 0; tm < TM; tm++) {
      out_vec[out_row + tm] = result[tm];
    }
  }
}
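A hypothetical plain-C++ reference for the kernel's overall semantics as I read the fragments above (not MLX code): out = mat @ in_vec, with the operands masked per [blockM x blockN] tile and the output per blockM-row block; boolean masks skip tiles, multiplicative masks also scale them:

#include <cstdio>
#include <vector>

int main() {
  const int m = 4, n = 4, blockM = 2, blockN = 2;
  const int tiles_n = n / blockN;
  std::vector<float> mat(m * n, 1.f), vec(n, 1.f), out(m, 0.f);
  std::vector<float> mat_mask = {1, 0, 1, 1}; // 2 x 2 tile grid, row-major
  std::vector<float> vec_mask = {1, 1};
  std::vector<float> out_mask = {1, 0.5f};    // second row block scaled

  for (int r = 0; r < m; ++r) {
    int tr = r / blockM;
    if (out_mask[tr] == 0.f) { out[r] = 0.f; continue; } // masked-out block
    float acc = 0.f;
    for (int c = 0; c < n; ++c) {
      int tc = c / blockN;
      // Operand tile scale = product of matrix-tile and vector-tile masks
      acc += mat_mask[tr * tiles_n + tc] * vec_mask[tc] * mat[r * n + c] * vec[c];
    }
    out[r] = acc * out_mask[tr];
  }
  for (int r = 0; r < m; ++r) std::printf("%g ", out[r]); // 2 2 2 2
}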
 
 
 
// Vector matrix multiplication
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
 
static METAL_FUNC void run(
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    threadgroup T* tgp_memory [[threadgroup(0)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  // Per-thread accumulation and staging registers
  thread T result[TN] = {0};
  thread T inter[TN];
  thread T v_coeff[TM];

  const int thrM = SN != 32 ? simd_lid / SN : 0;
  const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);

  const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
  const int sgN = BN != 1 ? (simd_gid % BN) : 0;

  const int simdM = SM * sgM;
  const int simdN = SN * sgN;

  int cm = (simdM + thrM);
  int cn = (simdN + thrN);

  int bm = cm * TM;
  int bn = cn * TN;
 
  int out_col = tid.x * blockN + bn;
 
  const constant int* out_mask_strides = mask_strides;
  const constant int* mat_mask_strides = /* … */;
  const constant int* vec_mask_strides = /* … */;

  const int out_mask_offset = /* … */;

  int mat_mask_offset = /* … */;
  int vec_mask_offset = 0;
  const int mat_mask_step = /* … */;
  const int vec_mask_step = /* … */;
 
  // Out mask: write zeros and return if the block is masked out; a
  // multiplicative mask becomes the output scale instead
  T out_scale{1};
  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (!mask_out) {
      if (cm == 0 && out_col < out_vec_size) {
        if (out_col + TN <= out_vec_size) {
          for (int tn = 0; tn < TN; tn++) {
            out_vec[out_col + tn] = T(0.);
          }
        } else {
          for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
            out_vec[out_col + tn] = T(0.);
          }
        }
      }
      return;
    }

    if (has_mul_output_mask) {
      out_scale = T(mask_out);
    }
  }
 
  // Loop over in_vec in blockM-sized chunks; handle the leftover separately
  constexpr const uniform<int> loop_stride = make_uniform(blockM);
  const uniform<int> in_size = make_uniform(in_vec_size);
  const uniform<int> n_iter = in_size / loop_stride;
  const uniform<int> last_iter = loop_stride * n_iter;
  const uniform<int> leftover = in_size - last_iter;
 
  // Edge handling: shift the last column block fully into bounds
  if (out_col < out_vec_size) {
    out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
 
    for (int i = 0; i < n_iter; ++i) {
      threadgroup_barrier(mem_flags::mem_none);

      if (!has_operand_mask ||
          (bool(mat_mask[mat_mask_offset]) &&
           bool(vec_mask[vec_mask_offset]))) {
        T block_scale{1};
        if (has_mul_operand_mask) {
          block_scale =
              T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
        }

        for (int tm = 0; tm < TM; tm++) {
          v_coeff[tm] = in_vec[bm + tm];
        }

        if (has_mul_operand_mask) {
          for (int tm = 0; tm < TM; tm++) {
            v_coeff[tm] *= block_scale;
          }
        }

        for (int tm = 0; tm < TM; tm++) {
          for (int tn = 0; tn < TN; tn++) {
            inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
          }
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += v_coeff[tm] * inter[tn];
          }
        }
      }

      bm += blockM;
      mat_mask_offset += mat_mask_step;
      vec_mask_offset += vec_mask_step;
    }
 
    if (leftover > 0 &&
        (!has_operand_mask ||
         (bool(mat_mask[mat_mask_offset]) &&
          bool(vec_mask[vec_mask_offset])))) {
      T block_scale{1};
      if (has_mul_operand_mask) {
        block_scale =
            T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
      }

      for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
        v_coeff[tm] = in_vec[bm + tm];

        if (has_mul_operand_mask) {
          v_coeff[tm] *= block_scale;
        }

        for (int tn = 0; tn < TN; tn++) {
          inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
        }

        for (int tn = 0; tn < TN; tn++) {
          result[tn] += v_coeff[tm] * inter[tn];
        }
      }
    }
  }

  // Apply out scale
  if (has_mul_output_mask) {
    for (int tn = 0; tn < TN; tn++) {
      result[tn] *= out_scale;
    }
  }
 
  // Simdgroup accumulations: reduce each column sum across the SM row lanes
  for (int tn = 0; tn < TN; tn++) {
    for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
      result[tn] += simd_shuffle_down(result[tn], SN * sm);
    }
  }

  // Threadgroup accumulation results
  if (needs_tgp_reduction) {
    threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
    if (thrM == 0) {
      for (int tn = 0; tn < TN; tn++) {
        tgp_results[tn] = result[tn];
      }

      threadgroup_barrier(mem_flags::mem_none);

      if (sgM == 0) {
        for (int sgm = 1; sgm < BM; sgm++) {
          for (int tn = 0; tn < TN; tn++) {
            result[tn] += tgp_results[sgm * (blockN + TN) + tn];
          }
        }
      }
    }
  }
 
  // Write outputs
  if (cm == 0 && out_col < out_vec_size) {
    for (int j = 0; j < TN; j++) {
      out_vec[out_col + j] = result[j];
    }
  }
}
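The transposed kernel accumulates into output columns rather than rows, which is why its reduction runs across the SM row lanes. A minimal plain-C++ reference of that traversal (out = vec @ mat):

#include <cstdio>

int main() {
  const int m = 3, n = 2;
  float mat[m * n] = {1, 2, 3, 4, 5, 6}; // row-major m x n
  float vec[m] = {1, 1, 2};
  float out[n] = {0, 0};
  // Walk down rows, accumulating each row's contribution to all columns
  for (int r = 0; r < m; ++r)
    for (int c = 0; c < n; ++c)
      out[c] += vec[r] * mat[r * n + c];
  std::printf("%g %g\n", out[0], out[1]); // 14 18
}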
 
 
 
// Matrix vector multiplication
template < /* … */
    const bool kDoNCBatch>
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
 
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  using gemv_kernel =
      GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
 
  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
 
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;
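A sketch of the assumed elem_to_loc semantics (the real helper may differ in detail): unravel the flat batch index tid.z over batch_shape and dot the per-dimension coordinates with the strides to get an element offset:

#include <cstdio>

long elem_to_loc(unsigned elem, const int* shape, const long* strides, int ndim) {
  long loc = 0;
  // Peel off coordinates from the innermost dimension outward
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  int shape[2] = {3, 4};
  long strides[2] = {100, 10}; // non-contiguous batch strides
  // batch element 7 unravels to (1, 3) -> offset 1*100 + 3*10 = 130
  std::printf("%ld\n", elem_to_loc(7, shape, strides, 2));
}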
 
  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
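Hypothetical dispatch arithmetic implied by the kernel's indexing (out_row = tid.x * blockM + bm, batching via tid.z); this is not the library's host code, just the grid shape the kernel expects:

#include <cstdio>

int main() {
  const int out_vec_size = 1000, n_batch = 4;
  const int blockM = 32; // rows produced per threadgroup (see GEMVKernel)
  const int grid_x = (out_vec_size + blockM - 1) / blockM; // ceil division
  std::printf("threadgroups: %d x 1 x %d\n", grid_x, n_batch);
}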
 
 
// Vector matrix multiplication
template < /* … */
    const bool kDoNCBatch>
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
 
    const device T* mat [[buffer(0)]],
    const device T* in_vec [[buffer(1)]],
    device T* out_vec [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    const constant int& marix_ld [[buffer(6)]],
    const constant int& batch_ndim [[buffer(9)]],
    const constant int* batch_shape [[buffer(10)]],
    const constant size_t* vector_batch_stride [[buffer(11)]],
    const constant size_t* matrix_batch_stride [[buffer(12)]],
    const device out_mask_t* out_mask [[buffer(20)]],
    const device op_mask_t* mat_mask [[buffer(21)]],
    const device op_mask_t* vec_mask [[buffer(22)]],
    const constant int* mask_strides [[buffer(23)]],
    const constant size_t* mask_batch_strides [[buffer(24)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
 
  using gemv_kernel =
      GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
  threadgroup T tgp_memory
      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
 
  // Update batch offsets
  if (kDoNCBatch) {
    in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
    mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);

    if (has_output_mask) {
      out_mask +=
          elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      const constant size_t* mask_strides_mat = mask_batch_strides;
      const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

      mat_mask += batch_offsets.x;
      vec_mask += batch_offsets.y;
    }

  } else {
 
    in_vec += tid.z * vector_batch_stride[0];
    mat += tid.z * matrix_batch_stride[0];

    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += batch_ndim;
    }

    if (has_operand_mask) {
      mat_mask += tid.z * mask_batch_strides[0];
      vec_mask += tid.z * mask_batch_strides[batch_ndim];
    }
  }

  out_vec += tid.z * out_vec_size;
 
  gemv_kernel::run(
      mat,
      in_vec,
      out_vec,
      in_vec_size,
      out_vec_size,
      marix_ld,
      out_mask,
      mat_mask,
      vec_mask,
      mask_strides,
      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
      tid,
      lid,
      simd_gid,
      simd_lid);
}
 
 