Go to the source code of this file.
|
template<typename T , typename U , int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned> |
void | gemm_splitk (const device T *A, const device T *B, device U *C, const constant GEMMSpiltKParams *params, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid) |
|
template<typename AccT , typename OutT , typename Epilogue = TransformNone<OutT, AccT>> |
void | gemm_splitk_accum (const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, uint2 gid) |
|
template<typename AccT , typename OutT , typename Epilogue = TransformAxpby<OutT, AccT>> |
void | gemm_splitk_accum_axpby (const device AccT *C_split, device OutT *D, const constant int &k_partitions, const constant int &partition_stride, const constant int &ldd, const device OutT *C, const constant int &ldc, const constant int &fdc, const constant float &alpha, const constant float &beta, uint2 gid) |
|
◆ gemm_splitk()
template<typename T, typename U, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void gemm_splitk(
    const device T *A,
    const device T *B,
    device U *C,
    const constant GEMMSpiltKParams *params,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)
◆ gemm_splitk_accum()
template<typename AccT, typename OutT, typename Epilogue = TransformNone<OutT, AccT>>
void gemm_splitk_accum(
    const device AccT *C_split,
    device OutT *D,
    const constant int &k_partitions,
    const constant int &partition_stride,
    const constant int &ldd,
    uint2 gid)
◆ gemm_splitk_accum_axpby()
template<typename AccT, typename OutT, typename Epilogue = TransformAxpby<OutT, AccT>>
void gemm_splitk_accum_axpby(
    const device AccT *C_split,
    device OutT *D,
    const constant int &k_partitions,
    const constant int &partition_stride,
    const constant int &ldd,
    const device OutT *C,
    const constant int &ldc,
    const constant int &fdc,
    const constant float &alpha,
    const constant float &beta,
    uint2 gid)