MLX
Loading...
Searching...
No Matches
Static Public Member Functions | Static Public Attributes | List of all members
GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN > Struct Template Reference

#include <gemv_masked.h>

Static Public Member Functions

static METAL_FUNC void load_unsafe (const device T *src, thread T dst[TN], const int src_offset=0)
 
static METAL_FUNC void load_safe (const device T *src, thread T dst[TN], const int src_offset=0, const int src_size=TN)
 
static METAL_FUNC void run (const device T *mat, const device T *in_vec, device T *out_vec, const constant int &in_vec_size, const constant int &out_vec_size, const constant int &matrix_ld, const device out_mask_t *out_mask, const device op_mask_t *mat_mask, const device op_mask_t *vec_mask, const constant int *mask_strides, threadgroup T *tgp_memory, uint3 tid, uint3 lid, uint simd_gid, uint simd_lid)
 

Static Public Attributes

static constant constexpr const int threadsM = BM * SM
 
static constant constexpr const int threadsN = BN * SN
 
static constant constexpr const int blockM = threadsM * TM
 
static constant constexpr const int blockN = threadsN * TN
 
static constant constexpr const bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>
 
static constant constexpr const bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>
 
static constant constexpr const bool has_mul_operand_mask
 
static constant constexpr const bool has_mul_output_mask
 
static constant constexpr const short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0
 
static constant constexpr const bool needs_tgp_reduction = BN > 1
 

Member Function Documentation

◆ load_safe()

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::load_safe ( const device T * src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN )
inline static

◆ load_unsafe()

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::load_unsafe ( const device T * src,
thread T dst[TN],
const int src_offset = 0 )
inline static

◆ run()

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
static METAL_FUNC void GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::run ( const device T * mat,
const device T * in_vec,
device T * out_vec,
const constant int & in_vec_size,
const constant int & out_vec_size,
const constant int & matrix_ld,
const device out_mask_t * out_mask,
const device op_mask_t * mat_mask,
const device op_mask_t * vec_mask,
const constant int * mask_strides,
threadgroup T * tgp_memory,
uint3 tid,
uint3 lid,
uint simd_gid,
uint simd_lid )
inline static

Member Data Documentation

◆ blockM

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::blockM = threadsM * TM
static constexpr

◆ blockN

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::blockN = threadsN * TN
static constexpr

◆ has_mul_operand_mask

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_mul_operand_mask
static constexpr
Initial value:
= has_operand_mask && !metal::is_same_v<op_mask_t, bool>
(Referenced member: static constant constexpr const bool has_operand_mask — definition at gemv_masked.h:63.)

◆ has_mul_output_mask

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_mul_output_mask
static constexpr
Initial value:
= has_output_mask && !metal::is_same_v<out_mask_t, bool>
(Referenced member: static constant constexpr const bool has_output_mask — definition at gemv_masked.h:64.)

◆ has_operand_mask

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>
static constexpr

◆ has_output_mask

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>
static constexpr

◆ needs_tgp_reduction

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const bool GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::needs_tgp_reduction = BN > 1
static constexpr

◆ tgp_mem_size

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const short GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0
static constexpr

◆ threadsM

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::threadsM = BM * SM
static constexpr

◆ threadsN

template<typename T , typename out_mask_t , typename op_mask_t , const int BM, const int BN, const int SM, const int SN, const int TM, const int TN>
constant constexpr const int GEMVKernel< T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN >::threadsN = BN * SN
static constexpr

The documentation for this struct was generated from the following file: