template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
 
  // Appease the compiler: lid is part of the kernel ABI but unused here
  (void)lid;

  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
 
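  // Compile-time mask configuration: a mask is present when its type is not
  // nomask_t, and it is multiplicative (it scales values rather than just
  // gating them) when its element type is not bool.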
  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  // Each BM x BM mask block covers BM / BK of the BK-deep k iterations
  constexpr short k_mask_factor = short(BM / BK);

  using gemm_kernel = GEMMKernel<
      T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
 
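  // Map the launch grid to block tiles. The swizzle reorders threadgroups so
  // that consecutive tid.x values walk a band of 2^swizzle_log tile rows,
  // which is intended to improve reuse of A and B blocks in the caches.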
  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if the block tile is out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
 
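  // batch_strides packs batch_ndim strides for A, then for B; the batch
  // strides of whichever masks are present follow, in out_mask, lhs_mask,
  // rhs_mask order.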
  const constant auto* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;

  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs = mask_batch_strides;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }
 
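  // Locate this batch element in A and B: a general strided lookup (with
  // broadcasting) for batch_ndim > 1, a single stride multiply otherwise.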
  if (params->batch_ndim > 1) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;
 
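  // Offset A, B, and D to this threadgroup's block tile.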
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;
 
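  // mask_strides holds an (x, y) stride pair per mask, packed in out_mask,
  // lhs_mask, rhs_mask order, with absent masks skipped.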
  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);

  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  short k_factor_cnt = k_mask_factor;

  ScaleOp<float> out_mask_op;
  ScaleOp<T> lhs_mask_op;
  ScaleOp<T> rhs_mask_op;
 
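  // Handle the output mask first: a zeroed-out block lets the whole
  // threadgroup skip the GEMM entirely, while a multiplicative mask becomes a
  // scale applied in the epilogue.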
  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return if the output block is fully masked out
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      // Tile threads in threadgroup
      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);

      D += bi * params->ldd + bj;

      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);

      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
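  // Block sizes at the edges, clipped to the matrix bounds when unaligned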
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;
 
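  // When K is not a multiple of BK, the partial K block is handled first:
  // jump the loaders to the tail, do one bounds-checked load and matmul, then
  // rewind so the aligned loop below runs unchanged.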
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move the loader sources to the K tail
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load the partial K block with bounds checks
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do the matmul for the tail block
      mma_op.mma(As, Bs);

      // Reset the loader sources to the start of K
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }
 
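  // Main K loop. One BM x BM mask block spans k_mask_factor = BM / BK
  // iterations, so the mask offsets only step every k_mask_factor iterations.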
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;
  }
 
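  // Edge tiles in M or N take the same masked loop, but with bounds-checked
  // loads on the unaligned dimensions and a safe store at the end.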
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup, bounds-checked on unaligned dims
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
 
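// Bool-mask variant: mask blocks only gate work (no scaling), and operand
// masking is a compile-time template flag rather than a mask type.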
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
 
  // Appease the compiler: lid is unused
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  // Exit early if the block tile is out of bounds
  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }
 
  if (params->batch_ndim > 1) {
    const constant auto* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
    if (has_operand_mask) {
      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
    }
  }
 
  if (params->batch_ndim > 1) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;
  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;
 
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;
 
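  // Read this block's output mask up front; here the (x, y) stride pairs live
  // at fixed slots in mask_strides: out at 0-1, lhs at 2-3, rhs at 4-5.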
  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return if the output block is masked out
  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }
 
  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);
 
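  // MNK-aligned path: each K block's load and matmul is gated on the lhs and
  // rhs mask blocks; fully masked blocks are skipped outright.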
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail: the partial K block, loaded with bounds checks
    if (!K_aligned) {
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;
  }
 
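  // Edge-tile path: identical gating, with bounds-checked loads for partial
  // blocks and a safe store when M or N is unaligned.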
  // MN unaligned loop
  else {
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup, bounds-checked on unaligned dims
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}