|
template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm(const device T *A, const device T *B, device T *D, const constant GEMMParams *params, const constant int *batch_shape, const constant size_t *batch_strides, const device out_mask_t *out_mask, const device op_mask_t *lhs_mask, const device op_mask_t *rhs_mask, const constant int *mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
|
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm(const device T *A, const device T *B, device T *D, const constant GEMMParams *params, const constant int *batch_shape, const constant size_t *batch_strides, const device bool *out_mask, const device bool *lhs_mask, const device bool *rhs_mask, const constant int *mask_strides, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 lid)
|
template<typename T, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned, bool has_operand_mask = false>
void block_masked_gemm(
    const device T *A,
    const device T *B,
    device T *D,
    const constant GEMMParams *params,
    const constant int *batch_shape,
    const constant size_t *batch_strides,
    const device bool *out_mask,
    const device bool *lhs_mask,
    const device bool *rhs_mask,
    const constant int *mask_strides,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)
template<typename T, typename out_mask_t, typename op_mask_t, int BM, int BN, int BK, int WM, int WN, bool transpose_a, bool transpose_b, bool MN_aligned, bool K_aligned>
void block_masked_gemm(
    const device T *A,
    const device T *B,
    device T *D,
    const constant GEMMParams *params,
    const constant int *batch_shape,
    const constant size_t *batch_strides,
    const device out_mask_t *out_mask,
    const device op_mask_t *lhs_mask,
    const device op_mask_t *rhs_mask,
    const constant int *mask_strides,
    uint simd_lane_id,
    uint simd_group_id,
    uint3 tid,
    uint3 lid)