Go to the source code of this file.
 | 
| template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float> |
| void gemm(const device T *A, const device T *B, const device T *C, device T *D, const constant GEMMParams *params, const constant GEMMAddMMParams *addmm_params, const constant int *batch_shape, const constant size_t *batch_strides, const constant uint32_t *lhs_indices, const constant uint32_t *rhs_indices, const constant uint32_t *C_indices, const constant int *operand_shape, const constant size_t *operand_strides, const constant packed_int3 &operand_batch_ndim, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid) |
|   | 
◆ gemm()
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float>
void gemm(
    const device T *A,
    const device T *B,
    const device T *C,
    device T *D,
    const constant GEMMParams *params,
    const constant GEMMAddMMParams *addmm_params,
    const constant int *batch_shape,
    const constant size_t *batch_strides,
    const constant uint32_t *lhs_indices,
    const constant uint32_t *rhs_indices,
    const constant uint32_t *C_indices,
    const constant int *operand_shape,
    const constant size_t *operand_strides,
    const constant packed_int3 &operand_batch_ndim,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)
 
 
◆ align_K
◆ align_M
◆ align_N
◆ do_axpby
◆ do_gather
◆ gather_bias
◆ has_batch
◆ use_out_source
      
        
          | constant bool use_out_source |