Go to the source code of this file.
|
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float>
void gemm(const device T *A, const device T *B, const device T *C, device T *D, const constant GEMMParams *params, const constant GEMMAddMMParams *addmm_params, const constant int *batch_shape, const constant size_t *batch_strides, const constant uint32_t *lhs_indices, const constant uint32_t *rhs_indices, const constant uint32_t *C_indices, const constant int *operand_shape, const constant size_t *operand_strides, const constant packed_int3 &operand_batch_ndim, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
|
◆ gemm()
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, typename AccumType = float>
void gemm(
    const device T *A,
    const device T *B,
    const device T *C,
    device T *D,
    const constant GEMMParams *params,
    const constant GEMMAddMMParams *addmm_params,
    const constant int *batch_shape,
    const constant size_t *batch_strides,
    const constant uint32_t *lhs_indices,
    const constant uint32_t *rhs_indices,
    const constant uint32_t *C_indices,
    const constant int *operand_shape,
    const constant size_t *operand_strides,
    const constant packed_int3 &operand_batch_ndim,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)
◆ align_K
◆ align_M
◆ align_N
◆ do_axpby
◆ do_gather
◆ gather_bias
◆ has_batch
◆ use_out_source
constant bool use_out_source