#include <gemv_masked.h>
|
static METAL_FUNC void | load_unsafe (const device T *src, thread T dst[TN], const int src_offset=0) |
|
static METAL_FUNC void | load_safe (const device T *src, thread T dst[TN], const int src_offset=0, const int src_size=TN) |
|
static METAL_FUNC void | run (const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &matrix_ld, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, threadgroup T *tgp_memory, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid) |
|
◆ load_safe()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::load_safe (const device T *src, thread T dst[TN], const int src_offset = 0, const int src_size = TN) |
|
inline static |
◆ load_unsafe()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::load_unsafe (const device T *src, thread T dst[TN], const int src_offset = 0) |
|
inline static |
◆ run()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::run (const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &matrix_ld, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, threadgroup T *tgp_memory, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid) |
|
inline static |
◆ blockM
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::blockM = threadsM * TM |
|
static constexpr |
◆ blockN
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::blockN = threadsN * TN |
|
static constexpr |
◆ has_mul_operand_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_mul_operand_mask |
|
static constexpr |
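Initial value: an expression that references has_operand_mask (the full initializer text is not shown on this page).
Referenced symbol: static constant constexpr const bool has_operand_mask
Definition: gemv_masked.h:63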
◆ has_mul_output_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_mul_output_mask |
|
static constexpr |
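Initial value: an expression that references has_output_mask (the full initializer text is not shown on this page).
Referenced symbol: static constant constexpr const bool has_output_mask
Definition: gemv_masked.h:64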
◆ has_operand_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t> |
|
static constexpr |
◆ has_output_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_output_mask = !metal::is_same_v<out_mask_t, nomask_t> |
|
static constexpr |
◆ needs_tgp_reduction
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::needs_tgp_reduction = BN > 1 |
|
static constexpr |
◆ tgp_mem_size
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const short GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0 |
|
static constexpr |
◆ threadsM
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::threadsM = BM * SM |
|
static constexpr |
◆ threadsN
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::threadsN = BN * SN |
|
static constexpr |
The documentation for this struct was generated from the following file: