4template [[host_name("{name}")]] 
    5[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>( 
    6    const device {itype} *A [[buffer(0)]], 
    7    const device {itype} *B [[buffer(1)]], 
    8    const device {itype} *C [[buffer(2), function_constant(use_out_source)]], 
    9    device {itype} *D [[buffer(3)]], 
   10    const constant GEMMParams* params [[buffer(4)]], 
   11    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], 
   12    const constant int* batch_shape [[buffer(6)]], 
   13    const constant size_t* batch_strides [[buffer(7)]], 
   14    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], 
   15    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], 
   16    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], 
   17    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], 
   18    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], 
   19    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], 
   20    uint simd_lane_id [[thread_index_in_simdgroup]], 
   21    uint simd_group_id [[simdgroup_index_in_threadgroup]], 
   22    uint3 tid [[threadgroup_position_in_grid]], 
   23    uint3 lid [[thread_position_in_threadgroup]]); 
   27template [[host_name("{name}")]] [[kernel]] void 
   41    const device {itype}* A [[buffer(0)]], 
   42    const device {itype}* B [[buffer(1)]], 
   43    device {itype}* D [[buffer(3)]], 
   44    const constant GEMMParams* params [[buffer(4)]], 
   45    const constant int* batch_shape [[buffer(6)]], 
   46    const constant size_t* batch_strides [[buffer(7)]], 
   47    const device {outmasktype}* out_mask [[buffer(10)]], 
   48    const device {opmasktype}* lhs_mask [[buffer(11)]], 
   49    const device {opmasktype}* rhs_mask [[buffer(12)]], 
   50    const constant int* mask_strides [[buffer(13)]], 
   51    uint simd_lane_id [[thread_index_in_simdgroup]], 
   52    uint simd_group_id [[simdgroup_index_in_threadgroup]], 
   53    uint3 tid [[threadgroup_position_in_grid]], 
   54    uint3 lid [[thread_position_in_threadgroup]]); 
   58template [[host_name("{name}")]] [[kernel]] void 
   71    const device {itype}* A [[buffer(0)]], 
   72    const device {itype}* B [[buffer(1)]], 
   73    device {otype}* C [[buffer(2)]], 
   74    const constant GEMMSpiltKParams* params [[buffer(3)]], 
   75    uint simd_lane_id [[thread_index_in_simdgroup]], 
   76    uint simd_group_id [[simdgroup_index_in_threadgroup]], 
   77    uint3 tid [[threadgroup_position_in_grid]], 
   78    uint3 lid [[thread_position_in_threadgroup]]); 
   82template [[host_name("{name}")]] [[kernel]] void 
   83gemm_splitk_accum<{atype}, {otype}>( 
   84    const device {atype}* C_split [[buffer(0)]], 
   85    device {otype}* D [[buffer(1)]], 
   86    const constant int& k_partitions [[buffer(2)]], 
   87    const constant int& partition_stride [[buffer(3)]], 
   88    const constant int& ldd [[buffer(4)]], 
   89    uint2 gid [[thread_position_in_grid]]); 
   93template [[host_name("{name}")]] [[kernel]] void 
   94gemm_splitk_accum_axpby<{atype}, {otype}>( 
   95    const device {atype}* C_split [[buffer(0)]], 
   96    device {otype}* D [[buffer(1)]], 
   97    const constant int& k_partitions [[buffer(2)]], 
   98    const constant int& partition_stride [[buffer(3)]], 
   99    const constant int& ldd [[buffer(4)]], 
  100    const device {otype}* C [[buffer(5)]], 
  101    const constant int& ldc [[buffer(6)]], 
  102    const constant int& fdc [[buffer(7)]], 
  103    const constant float& alpha [[buffer(8)]], 
  104    const constant float& beta [[buffer(9)]], 
  105    uint2 gid [[thread_position_in_grid]]); 
constexpr std::string_view steel_gemm_splitk_accum_kernels
Definition steel_gemm.h:81
 
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels
Definition steel_gemm.h:92
 
constexpr std::string_view steel_gemm_fused_kernels
Definition steel_gemm.h:3
 
constexpr std::string_view steel_gemm_masked_kernels
Definition steel_gemm.h:26
 
constexpr std::string_view steel_gemm_splitk_kernels
Definition steel_gemm.h:57