#include <gemv_masked.h>
 | 
static METAL_FUNC void load_unsafe(const device T *src, thread T dst[TN], const int src_offset = 0)
static METAL_FUNC void load_safe(const device T *src, thread T dst[TN], const int src_offset = 0, const int src_size = TN)
static METAL_FUNC void run(const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &matrix_ld, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, threadgroup T *tgp_memory, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
◆ load_safe()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
static METAL_FUNC void GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::load_safe(
    const device T *src,
    thread T dst[TN],
    const int src_offset = 0,
    const int src_size = TN)

[inline] [static]
  
 
 
◆ load_unsafe()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
static METAL_FUNC void GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::load_unsafe(
    const device T *src,
    thread T dst[TN],
    const int src_offset = 0)

[inline] [static]
  
 
 
◆ run()
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
static METAL_FUNC void GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::run(
    const device T *mat,
    const device T *in_vec,
    device T *out_vec,
    const constant int &in_vec_size,
    const constant int &out_vec_size,
    const constant int &matrix_ld,
    const device out_mask_t *out_mask,
    const device op_mask_t *mat_mask,
    const device op_mask_t *vec_mask,
    const constant int *mask_strides,
    threadgroup T *tgp_memory,
    uint3 tid,
    uint3 lid,
    uint simd_gid,
    uint simd_lid)

[inline] [static]
  
 
 
◆ blockM
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const int GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::blockM = threadsM * TM

[static] [constexpr]
  
 
 
◆ blockN
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const int GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::blockN = threadsN * TN

[static] [constexpr]
  
 
 
◆ has_mul_operand_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
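constant constexpr const bool GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::has_mul_operand_mask

[static] [constexpr]

Initial value:
    = has_operand_mask …
(The extracted text shows only the leading `has_operand_mask` term of the initializer; any remainder was not captured. `has_operand_mask` is declared as `static constant constexpr const bool has_operand_mask`, defined at gemv_masked.h:63.)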
 
 
 
 
◆ has_mul_output_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
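constant constexpr const bool GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::has_mul_output_mask

[static] [constexpr]

Initial value:
    = has_output_mask …
(The extracted text shows only the leading `has_output_mask` term of the initializer; any remainder was not captured. `has_output_mask` is declared as `static constant constexpr const bool has_output_mask`, defined at gemv_masked.h:64.)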
 
 
 
 
◆ has_operand_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const bool GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>

[static] [constexpr]
  
 
 
◆ has_output_mask
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const bool GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>

[static] [constexpr]
  
 
 
◆ needs_tgp_reduction
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const bool GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::needs_tgp_reduction = BN > 1

[static] [constexpr]
  
 
 
◆ tgp_mem_size
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const short GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0

[static] [constexpr]
  
 
 
◆ threadsM
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const int GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::threadsM = BM * SM

[static] [constexpr]
  
 
 
◆ threadsN
template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN> 
  
  
      
        
constant constexpr const int GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>::threadsN = BN * SN

[static] [constexpr]
  
 
 
The documentation for this struct was generated from the following file: gemv_masked.h